Skip to content

Commit 9352d57

Browse files
Input shape from the input dataset
1 parent 213a397 commit 9352d57

File tree

1 file changed

+17
-13
lines changed

1 file changed

+17
-13
lines changed

examples/openvino/aot/aot_openvino_compiler.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -88,19 +88,20 @@ def dump_inputs(calibration_dataset, dest_path):
8888

8989

9090
def main(suite: str, model_name: str, input_shape, quantize: bool, validate: bool, dataset_path: str, device: str):
91-
# Ensure input_shape is a tuple
92-
if isinstance(input_shape, list):
93-
input_shape = tuple(input_shape)
94-
elif not isinstance(input_shape, tuple):
95-
msg = "Input shape must be a list or tuple."
96-
raise ValueError(msg)
97-
98-
calibration_dataset = None
99-
10091
# Load the selected model
10192
model = load_model(suite, model_name)
10293
model = model.eval()
10394

95+
if dataset_path:
96+
calibration_dataset = load_calibration_dataset(dataset_path, suite, model, model_name)
97+
input_shape = tuple(next(iter(calibration_dataset))[0].shape)
98+
print(f"Input shape retrieved from the model config: {input_shape}")
99+
# Ensure input_shape is a tuple
100+
elif isinstance(input_shape, list):
101+
input_shape = tuple(input_shape)
102+
else:
103+
msg = "Input shape must be a list or tuple."
104+
raise ValueError(msg)
104105
# Provide input
105106
example_args = (torch.randn(*input_shape),)
106107

@@ -116,7 +117,6 @@ def main(suite: str, model_name: str, input_shape, quantize: bool, validate: boo
116117
if not dataset_path:
117118
msg = "Quantization requires a calibration dataset."
118119
raise ValueError(msg)
119-
calibration_dataset = load_calibration_dataset(dataset_path, suite, model, model_name)
120120

121121
captured_model = aten_dialect.module()
122122
quantizer = OpenVINOQuantizer()
@@ -154,8 +154,13 @@ def transform(x):
154154
print(f"Model exported and saved as {model_file_name} on {device}.")
155155

156156
if validate:
157-
if calibration_dataset is None:
158-
calibration_dataset = load_calibration_dataset(dataset_path, suite, model, model_name)
157+
if suite == "huggingface":
158+
msg = f"Validation of {suite} models did not support yet."
159+
raise ValueError(msg)
160+
161+
if not dataset_path:
162+
msg = "Validateion requires a calibration dataset."
163+
raise ValueError(msg)
159164

160165
print("Start validation of the quantized model:")
161166
# 1: Dump inputs
@@ -207,7 +212,6 @@ def transform(x):
207212
parser.add_argument(
208213
"--input_shape",
209214
type=eval,
210-
required=True,
211215
help="Input shape for the model as a list or tuple (e.g., [1, 3, 224, 224] or (1, 3, 224, 224)).",
212216
)
213217
parser.add_argument("--quantize", action="store_true", help="Enable model quantization.")

0 commit comments

Comments
 (0)