Skip to content

Commit 32f4339

Browse files
studyingeugenefracape
authored andcommitted
refactor: simplify forward() permutation logic for compile-friendly execution
What's changed - Replace tensor-based perm construction with list-based version - Add explicit inverse permutation for correctness - Remove TorchScript-specific branches Why - Compile-friendly: torch.compile/AOTAutograd prefer static Python control flow and index lists over device tensor construction inside forward. Replacing torch.tensor([...]), torch.arange(...), and torch.cat(...) with plain Python lists reduces graph breaks and guard complexity, improving compilation stability and cache reuse.
1 parent ff16d32 commit 32f4339

File tree

1 file changed

+8
-21
lines changed

1 file changed

+8
-21
lines changed

compressai/entropy_models/entropy_models.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -474,28 +474,18 @@ def forward(
474474
if training is None:
475475
training = self.training
476476

477-
if not torch.jit.is_scripting():
478-
# x from B x C x ... to C x B x ...
479-
perm = torch.cat(
480-
(
481-
torch.tensor([1, 0], dtype=torch.long, device=x.device),
482-
torch.arange(2, x.ndim, dtype=torch.long, device=x.device),
483-
)
484-
)
485-
inv_perm = perm
486-
else:
487-
raise NotImplementedError()
488-
# TorchScript in 2D for static inference
489-
# Convert to (channels, ... , batch) format
490-
# perm = (1, 2, 3, 0)
491-
# inv_perm = (3, 0, 1, 2)
477+
D = x.dim()
478+
# B C ... -> C B ...
479+
perm = [1, 0] + list(range(2, D))
480+
inv_perm = [0] * D
481+
for i, p in enumerate(perm):
482+
inv_perm[p] = i
492483

493484
x = x.permute(*perm).contiguous()
494485
shape = x.size()
495486
values = x.reshape(x.size(0), 1, -1)
496487

497488
# Add noise or quantize
498-
499489
outputs = self.quantize(
500490
values, "noise" if training else "dequantize", self._get_medians()
501491
)
@@ -510,11 +500,8 @@ def forward(
510500
# likelihood = torch.zeros_like(outputs)
511501

512502
# Convert back to input tensor shape
513-
outputs = outputs.reshape(shape)
514-
outputs = outputs.permute(*inv_perm).contiguous()
515-
516-
likelihood = likelihood.reshape(shape)
517-
likelihood = likelihood.permute(*inv_perm).contiguous()
503+
outputs = outputs.reshape(shape).permute(*inv_perm).contiguous()
504+
likelihood = likelihood.reshape(shape).permute(*inv_perm).contiguous()
518505

519506
return outputs, likelihood
520507

0 commit comments

Comments
 (0)