File tree Expand file tree Collapse file tree 2 files changed +7
-11
lines changed
modelopt/onnx/quantization Expand file tree Collapse file tree 2 files changed +7
-11
lines changed Original file line number Diff line number Diff line change 9898# supported and working
9999CLIP_MIN = 1e-5
100100
101- def safe_cupy_array (tensor ):
102- """
103- Convert ml_dtypes.int4 tensor to numpy.int8 for CuPy compatibility.
101+ def safe_cupy_array (tensor ):
102+ """Convert ml_dtypes.int4 tensor to numpy.int8 for CuPy compatibility.
104103
105104 In ONNX 1.19, int4 tensors use ml_dtypes.int4 which CuPy doesn't support.
106105 This function converts them to regular numpy.int8 while preserving values.
107-
108106 Args:
109- tensor: numpy array that may have ml_dtypes.int4 dtype
110-
107+ tensor: numpy array that may have ml_dtypes.int4 dtype
111108 Returns:
112- cupy or numpy array (if cupy is not supported) with numpy.int8 dtype if input was ml_dtypes.int4, otherwise unchanged
109+ cupy or numpy array (if cupy is not supported) with numpy.int8 dtype if input was ml_dtypes.int4,
110+ otherwise unchanged
113111 """
114112 try :
115113 import ml_dtypes
114+
116115 if hasattr (tensor , 'dtype' ) and tensor .dtype == ml_dtypes .int4 :
117116 return np .asarray (tensor .astype (numpy .int8 ))
118117 except ImportError :
Original file line number Diff line number Diff line change 2020from functools import partial
2121
2222import torch
23- from _test_utils .import_helper import skip_if_no_libcudnn , skip_if_onnx_version_above_1_18
23+ from _test_utils .import_helper import skip_if_no_libcudnn
2424from _test_utils .onnx_quantization .lib_test_models import SimpleMLP , export_as_onnx , find_init
2525from _test_utils .torch_quantization .quantize_common import get_awq_config
2626
4040
4141
4242def test_int4_awq (tmp_path ):
43- # skip_if_onnx_version_above_1_18()
44-
4543 def _forward_loop (model , dataloader ):
4644 """Forward loop for calibration."""
4745 for data in dataloader :
@@ -115,7 +113,6 @@ def _forward_loop(model, dataloader):
115113
116114
117115def test_int4_awq_cuda (tmp_path ):
118- # skip_if_onnx_version_above_1_18()
119116 skip_if_no_libcudnn ()
120117 block_size = 128
121118
You can’t perform that action at this time.
0 commit comments