Skip to content

Commit 5341111

Browse files
committed
minor refactoring of contracter
1 parent f47ab66 commit 5341111

File tree

3 files changed

+14
-9
lines changed

3 files changed

+14
-9
lines changed

allegro/nn/_strided/_contract.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,15 +202,19 @@ def forward(
202202
dim=0,
203203
dim_size=scatter_dim_size,
204204
)
205-
x2 = torch.index_select(x2_scatter, 0, idxs)
206205

207206
# === perform TP ===
208207
# convert to strided shape
209208
x1 = x1.reshape(-1, self.mul, self.base_dim1)
210-
x2 = x2.reshape(-1, self.mul, self.base_dim2)
211-
return self._contract(x1, x2)
209+
x2_scatter = x2_scatter.reshape(-1, self.mul, self.base_dim2)
210+
return self._contract_conv(x1, x2_scatter, idxs)
211+
212+
def _contract_conv(
213+
self, x1: torch.Tensor, x2: torch.Tensor, idxs: torch.Tensor
214+
) -> torch.Tensor:
215+
# index select from scattered x2
216+
x2 = torch.index_select(x2, 0, idxs)
212217

213-
def _contract(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
214218
# for shared weights, we can precontract weights and w3j so they can be frozen together
215219
# this is usually advantageous for inference, since the weights would have to be
216220
# multiplied in anyway at some point

allegro/nn/_strided/_cueq_contracter.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,5 +138,4 @@ def forward(
138138
.contiguous()
139139
)
140140
else:
141-
x2 = torch.index_select(x2_scatter, 0, idxs)
142-
return self._contract(x1, x2)
141+
return self._contract_conv(x1, x2_scatter, idxs)

allegro/nn/_strided/_flashallegro.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -722,12 +722,14 @@ def __init__(self, **kwargs):
722722
"p_to_nnz_mapper_bwd2", p_to_nnz_mapper_bwd2, persistent=False
723723
)
724724

725-
def _contract(self, x1, x2):
725+
def _contract_conv(self, x1, x2, idxs):
726726
# runtime conditions for triggering kernel code path
727727
if x1.is_cuda and not self.training:
728+
# index select for triton kernel
729+
x2_indexed = torch.index_select(x2, 0, idxs)
728730
return torch.ops.triton.flashallegro_forward(
729731
x1,
730-
x2,
732+
x2_indexed,
731733
self.mode,
732734
self.indptr_fwd,
733735
self.indptr_bwd1,
@@ -752,4 +754,4 @@ def _contract(self, x1, x2):
752754
x1.dtype,
753755
)
754756
else:
755-
return super()._contract(x1, x2)
757+
return super()._contract_conv(x1, x2, idxs)

0 commit comments

Comments (0)