Skip to content

Commit b106f8b

Browse files
committed
fix: Fix PTQ calibration when there are multiple inputs
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent d3041f5 commit b106f8b

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

py/torch_tensorrt/ptq.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,13 @@ def get_batch(self, names):
3030

3131
batch = self.dataset_iterator.next()
3232
self.current_batch_idx += self.batch_size
33-
# Treat the first element as input and others as targets.
33+
inputs_gpu=[]
3434
if isinstance(batch, list):
35-
batch = batch[0].to(self.device)
35+
for example in batch:
36+
inputs_gpu.append(example.to(self.device).data_ptr())
3637
else:
37-
batch = batch.to(self.device)
38-
return [batch.data_ptr()]
38+
inputs_gpu.append(batch.to(self.device).data_ptr())
39+
return inputs_gpu
3940

4041

4142
def read_calibration_cache(self):

0 commit comments

Comments
 (0)