File tree Expand file tree Collapse file tree 2 files changed +7
-5
lines changed
modelopt/onnx/quantization Expand file tree Collapse file tree 2 files changed +7
-5
lines changed Original file line number Diff line number Diff line change @@ -69,8 +69,8 @@ def _export_tensor_proto(tensor: gs.Constant) -> onnx.TensorProto:
6969
7070 vals = tensor .values
7171 if _onnx_supports_int4 () and dtype in [onnx .TensorProto .INT4 , onnx .TensorProto .UINT4 ]:
72- signed = dtype == onnx .TensorProto .INT4
73- if ( signed ) :
72+ signed = dtype == onnx .TensorProto .INT4
73+ if signed :
7474 vals = pack_float32_to_4bit_cpp_based (tensor .values , signed = signed ).astype (np .int8 )
7575 else :
7676 vals = pack_float32_to_4bit_cpp_based (tensor .values , signed = signed ).astype (np .uint8 )
Original file line number Diff line number Diff line change 9898# supported and working
9999CLIP_MIN = 1e-5
100100
101+
101102def safe_cupy_array (tensor ):
102103 """Convert ml_dtypes.int4 tensor to numpy.int8 for CuPy compatibility.
103104
104105 In ONNX 1.19, int4 tensors use ml_dtypes.int4 which CuPy doesn't support.
105106 This function converts them to regular numpy.int8 while preserving values.
107+
106108 Args:
107109 tensor: numpy array that may have ml_dtypes.int4 dtype
108110 Returns:
@@ -111,12 +113,12 @@ def safe_cupy_array(tensor):
111113 """
112114 try :
113115 import ml_dtypes
114-
115- if hasattr (tensor , ' dtype' ) and tensor .dtype == ml_dtypes .int4 :
116+
117+ if hasattr (tensor , " dtype" ) and tensor .dtype == ml_dtypes .int4 :
116118 return np .asarray (tensor .astype (numpy .int8 ))
117119 except ImportError :
118120 pass
119-
121+
120122 return np .asarray (tensor )
121123
122124
You can’t perform that action at this time.
0 commit comments