9898# supported and working
9999CLIP_MIN = 1e-5
100100
101-
102101def safe_cupy_array (tensor ):
103102 """Convert ml_dtypes.int4 tensor to numpy.int8 for CuPy compatibility.
104103
@@ -294,19 +293,23 @@ def quantize_rtn(
294293 scales [name ] = np .asnumpy (scales [name ])
295294 gemm_weights_quantized [name ] = numpy .asarray (qw )
296295 scales = reshape_scales_for_per_channel_nodes (scales , block_size , precision_info )
296+ dq_node_attributes = {"axis" : 0 , "block_size" : block_size }
297297 qdq .insert_dq_nodes (
298298 graph ,
299299 scales ,
300300 quantized_weights = gemm_weights_quantized ,
301+ attributes = dq_node_attributes ,
301302 precision_info = precision_info ,
302303 )
303304
304305 if gather_w_map is not None :
305306 assert gather_s_map is not None , "scale-map not found for quantizable gather nodes"
307+ gather_dq_node_attributes = {"axis" : gather_quantize_axis , "block_size" : gather_block_size }
306308 qdq .insert_dq_nodes (
307309 graph ,
308310 gather_s_map ,
309311 quantized_weights = gather_w_map ,
312+ attributes = gather_dq_node_attributes ,
310313 precision_info = precision_info ,
311314 )
312315 else :
@@ -322,8 +325,10 @@ def quantize_rtn(
322325 )
323326
324327 logger .info (f"RTN quantization completed in { time .time () - t_start :.2f} seconds" )
325- return gs .export_onnx (graph )
328+ model = gs .export_onnx (graph )
329+ model .ir_version = 10
326330
331+ return model
327332
328333class AWQClipHelper :
329334 """AWQ calibration helper class."""
0 commit comments