We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d3041f5 · commit b106f8b — Copy full SHA for b106f8b
py/torch_tensorrt/ptq.py
@@ -30,12 +30,13 @@ def get_batch(self, names):
30
31
batch = self.dataset_iterator.next()
32
self.current_batch_idx += self.batch_size
33
- # Treat the first element as input and others as targets.
+ inputs_gpu=[]
34
if isinstance(batch, list):
35
- batch = batch[0].to(self.device)
+ for example in batch:
36
+ inputs_gpu.append(example.to(self.device).data_ptr())
37
else:
- batch = batch.to(self.device)
38
- return [batch.data_ptr()]
+ inputs_gpu.append(batch.to(self.device).data_ptr())
39
+ return inputs_gpu
40
41
42
def read_calibration_cache(self):
0 commit comments