Skip to content

Commit a2cd076

Browse files
Merge pull request #674 from analysiscenter/release_fix
fixed torchmodel bug
2 parents f4373bb + 13a5de3 commit a2cd076

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

batchflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,4 @@
2727
from .utils_telegram import TelegramMessage
2828

2929

30-
__version__ = '0.7.6'
30+
__version__ = '0.7.7'

batchflow/models/torch/base.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1510,15 +1510,16 @@ def aggregate_microbatches(self, outputs, chunked_outputs, chunk_sizes, single_o
15101510
result = []
15111511
for i, _ in enumerate(outputs):
15121512
# All tensors for current `output_name`
1513-
chunked_output = [chunk_outputs[i][:chunk_size]
1514-
for chunk_outputs, chunk_size in zip(chunked_outputs, chunk_sizes)]
1513+
chunked_output = [chunk_outputs[i] for chunk_outputs in chunked_outputs]
15151514
if chunked_output[0].size != 1:
15161515
if len(chunked_output) == 1:
15171516
output_ = chunked_output[0]
15181517
elif isinstance(chunked_output[0], np.ndarray):
1519-
output_ = np.concatenate(chunked_output, axis=0)
1518+
output_ = np.concatenate([chunk_output[:chunk_size]
1519+
for chunk_output, chunk_size in zip(chunked_output, chunk_sizes)], axis=0)
15201520
else:
1521-
output_ = torch.cat(chunked_output, dim=0)
1521+
output_ = torch.cat([chunk_output[:chunk_size]
1522+
for chunk_output, chunk_size in zip(chunked_output, chunk_sizes)], dim=0)
15221523
result.append(output_)
15231524
else:
15241525
result.append(np.mean(chunked_output))

0 commit comments

Comments
 (0)