Skip to content

Commit 2b224b2

Browse files
authored
Merge pull request #1191 from pytorch/ptq_multi_input
fix: Fix PTQ calibration when there are multiple inputs
2 parents 6e399f8 + b106f8b commit 2b224b2

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

py/torch_tensorrt/ptq.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,13 @@ def get_batch(self, names):
3030

3131
batch = self.dataset_iterator.next()
3232
self.current_batch_idx += self.batch_size
33-
# Treat the first element as input and others as targets.
33+
inputs_gpu=[]
3434
if isinstance(batch, list):
35-
batch = batch[0].to(self.device)
35+
for example in batch:
36+
inputs_gpu.append(example.to(self.device).data_ptr())
3637
else:
37-
batch = batch.to(self.device)
38-
return [batch.data_ptr()]
38+
inputs_gpu.append(batch.to(self.device).data_ptr())
39+
return inputs_gpu
3940

4041

4142
def read_calibration_cache(self):

0 commit comments

Comments
 (0)