Commit 825f281: --batch_size

Parent: 9352d57

1 file changed: examples/openvino/aot/aot_openvino_compiler.py (+32, -7)
@@ -51,7 +51,7 @@ def load_model(suite: str, model_name: str):
         raise ValueError(msg)
 
 
-def load_calibration_dataset(dataset_path: str, suite: str, model: torch.nn.Module, model_name: str):
+def load_calibration_dataset(dataset_path: str, batch_size: int, suite: str, model: torch.nn.Module, model_name: str):
     val_dir = f"{dataset_path}/val"
 
     if suite == "torchvision":
@@ -62,7 +62,7 @@ def load_calibration_dataset(dataset_path: str, suite: str, model: torch.nn.Modu
     val_dataset = datasets.ImageFolder(val_dir, transform=transform)
 
     calibration_dataset = torch.utils.data.DataLoader(
-        val_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True
+        val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True
     )
 
     return calibration_dataset
@@ -77,7 +77,7 @@ def dump_inputs(calibration_dataset, dest_path):
     input_files, targets = [], []
     for idx, data in enumerate(calibration_dataset):
         feature, target = data
-        targets.append(target)
+        targets.extend(target)
         file_name = f"{dest_path}/input_{idx}_0.raw"
         if not isinstance(feature, torch.Tensor):
             feature = torch.tensor(feature)
@@ -87,13 +87,22 @@ def dump_inputs(calibration_dataset, dest_path):
     return input_files, targets
 
 
-def main(suite: str, model_name: str, input_shape, quantize: bool, validate: bool, dataset_path: str, device: str):
+def main(
+    suite: str,
+    model_name: str,
+    input_shape,
+    quantize: bool,
+    validate: bool,
+    dataset_path: str,
+    device: str,
+    batch_size: int,
+):
     # Load the selected model
     model = load_model(suite, model_name)
     model = model.eval()
 
     if dataset_path:
-        calibration_dataset = load_calibration_dataset(dataset_path, suite, model, model_name)
+        calibration_dataset = load_calibration_dataset(dataset_path, batch_size, suite, model, model_name)
         input_shape = tuple(next(iter(calibration_dataset))[0].shape)
         print(f"Input shape retrieved from the model config: {input_shape}")
         # Ensure input_shape is a tuple
@@ -192,7 +201,7 @@ def transform(x):
         predictions = []
         for i in range(len(input_files)):
             tensor = np.fromfile(out_path / f"output_{i}_0.raw", dtype=np.float32)
-            predictions.append(torch.argmax(torch.tensor(tensor)))
+            predictions.extend(torch.tensor(tensor).reshape(-1, 1000).argmax(-1))
 
         acc_top1 = accuracy_score(predictions, targets)
         print(f"acc@1: {acc_top1}")
@@ -214,6 +223,13 @@ def transform(x):
         type=eval,
         help="Input shape for the model as a list or tuple (e.g., [1, 3, 224, 224] or (1, 3, 224, 224)).",
     )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="Batch size for the validation. Default batch_size == 1."
+        " The dataset length must be evenly divisible by the batch size.",
+    )
     parser.add_argument("--quantize", action="store_true", help="Enable model quantization.")
     parser.add_argument(
         "--validate",
@@ -232,4 +248,13 @@ def transform(x):
 
     # Run the main function with parsed arguments
     with nncf.torch.disable_patching():
-        main(args.suite, args.model, args.input_shape, args.quantize, args.validate, args.dataset, args.device)
+        main(
+            args.suite,
+            args.model,
+            args.input_shape,
+            args.quantize,
+            args.validate,
+            args.dataset,
+            args.device,
+            args.batch_size,
+        )
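
Note (not part of the commit): a minimal sketch of why the batched code paths switch from append() to extend(). With --batch_size > 1, each DataLoader step yields a target tensor of shape (batch_size,), and each dumped output_<i>_0.raw holds batch_size * num_classes logits, so the buffer is reshaped to (batch_size, num_classes) before a per-row argmax. The 1000-class figure matches the hard-coded reshape in the diff (ImageNet-style torchvision/timm classifiers); the variable names and random data below are illustrative only.

import numpy as np
import torch

batch_size, num_classes = 4, 1000

# Stand-in for the flattened float32 buffer read back from one output_<i>_0.raw file.
logits = np.random.rand(batch_size * num_classes).astype(np.float32)

# Per-sample labels for the same batch, as the DataLoader would yield them.
target = torch.randint(0, num_classes, (batch_size,))

predictions, targets = [], []

# extend() flattens the per-batch tensors into one list entry per sample,
# keeping the two lists index-aligned for accuracy_score(predictions, targets).
predictions.extend(torch.tensor(logits).reshape(-1, num_classes).argmax(-1))
targets.extend(target)

assert len(predictions) == len(targets) == batch_size

This also explains the help text's divisibility requirement: if the dataset length is not a multiple of the batch size, the last raw file holds a smaller batch and a fixed (batch_size, 1000) interpretation of every file would misalign predictions and targets.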
