9898# supported and working 
9999CLIP_MIN  =  1e-5 
100100
101- 
102101def  safe_cupy_array (tensor ):
103102    """Convert ml_dtypes.int4 tensor to numpy.int8 for CuPy compatibility. 
104103
@@ -294,19 +293,23 @@ def quantize_rtn(
294293                scales [name ] =  np .asnumpy (scales [name ])
295294            gemm_weights_quantized [name ] =  numpy .asarray (qw )
296295        scales  =  reshape_scales_for_per_channel_nodes (scales , block_size , precision_info )
296+         dq_node_attributes  =  {"axis" : 0 , "block_size" : block_size }
297297        qdq .insert_dq_nodes (
298298            graph ,
299299            scales ,
300300            quantized_weights = gemm_weights_quantized ,
301+             attributes = dq_node_attributes ,
301302            precision_info = precision_info ,
302303        )
303304
304305        if  gather_w_map  is  not   None :
305306            assert  gather_s_map  is  not   None , "scale-map not found for quantizable gather nodes" 
307+             gather_dq_node_attributes  =  {"axis" : gather_quantize_axis , "block_size" : gather_block_size }
306308            qdq .insert_dq_nodes (
307309                graph ,
308310                gather_s_map ,
309311                quantized_weights = gather_w_map ,
312+                 attributes = gather_dq_node_attributes ,
310313                precision_info = precision_info ,
311314            )
312315    else :
@@ -322,8 +325,10 @@ def quantize_rtn(
322325            )
323326
324327    logger .info (f"RTN quantization completed in { time .time () -  t_start :.2f}   seconds" )
325-     return  gs .export_onnx (graph )
328+     model  =  gs .export_onnx (graph )
329+     model .ir_version  =  10 
326330
331+     return  model 
327332
328333class  AWQClipHelper :
329334    """AWQ calibration helper class.""" 
0 commit comments