@@ -88,19 +88,20 @@ def dump_inputs(calibration_dataset, dest_path):
88
88
89
89
90
90
def main (suite : str , model_name : str , input_shape , quantize : bool , validate : bool , dataset_path : str , device : str ):
91
- # Ensure input_shape is a tuple
92
- if isinstance (input_shape , list ):
93
- input_shape = tuple (input_shape )
94
- elif not isinstance (input_shape , tuple ):
95
- msg = "Input shape must be a list or tuple."
96
- raise ValueError (msg )
97
-
98
- calibration_dataset = None
99
-
100
91
# Load the selected model
101
92
model = load_model (suite , model_name )
102
93
model = model .eval ()
103
94
95
+ if dataset_path :
96
+ calibration_dataset = load_calibration_dataset (dataset_path , suite , model , model_name )
97
+ input_shape = tuple (next (iter (calibration_dataset ))[0 ].shape )
98
+ print (f"Input shape retrieved from the model config: { input_shape } " )
99
+ # Ensure input_shape is a tuple
100
+ elif isinstance (input_shape , list ):
101
+ input_shape = tuple (input_shape )
102
+ else :
103
+ msg = "Input shape must be a list or tuple."
104
+ raise ValueError (msg )
104
105
# Provide input
105
106
example_args = (torch .randn (* input_shape ),)
106
107
@@ -116,7 +117,6 @@ def main(suite: str, model_name: str, input_shape, quantize: bool, validate: boo
116
117
if not dataset_path :
117
118
msg = "Quantization requires a calibration dataset."
118
119
raise ValueError (msg )
119
- calibration_dataset = load_calibration_dataset (dataset_path , suite , model , model_name )
120
120
121
121
captured_model = aten_dialect .module ()
122
122
quantizer = OpenVINOQuantizer ()
@@ -154,8 +154,13 @@ def transform(x):
154
154
print (f"Model exported and saved as { model_file_name } on { device } ." )
155
155
156
156
if validate :
157
- if calibration_dataset is None :
158
- calibration_dataset = load_calibration_dataset (dataset_path , suite , model , model_name )
157
+ if suite == "huggingface" :
158
+ msg = f"Validation of { suite } models did not support yet."
159
+ raise ValueError (msg )
160
+
161
+ if not dataset_path :
162
+ msg = "Validateion requires a calibration dataset."
163
+ raise ValueError (msg )
159
164
160
165
print ("Start validation of the quantized model:" )
161
166
# 1: Dump inputs
@@ -207,7 +212,6 @@ def transform(x):
207
212
parser .add_argument (
208
213
"--input_shape" ,
209
214
type = eval ,
210
- required = True ,
211
215
help = "Input shape for the model as a list or tuple (e.g., [1, 3, 224, 224] or (1, 3, 224, 224))." ,
212
216
)
213
217
parser .add_argument ("--quantize" , action = "store_true" , help = "Enable model quantization." )
0 commit comments