Skip to content

Commit 6756e43

Browse files
committed
cleanup swap function
1 parent d8c2e90 commit 6756e43

File tree

1 file changed

+4
-9
lines changed

1 file changed

+4
-9
lines changed

more_math/Parser/TensorEvalVisitor.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,7 @@ def visitSfftFunc(self, ctx):
154154

155155
def visitSwapFunc(self, ctx):
156156
tsr = self.visit(ctx.expr(0))
157-
# Evaluate arguments for dim, idx1, idx2. They return full tensors, so we take scalar value.
158-
# We use .data.flatten()[0] to get the scalar safely from any shape
157+
159158
dim_t = self.visit(ctx.expr(1))
160159
idx1_t = self.visit(ctx.expr(2))
161160
idx2_t = self.visit(ctx.expr(3))
@@ -165,15 +164,11 @@ def visitSwapFunc(self, ctx):
165164
j = int(idx2_t.flatten()[0].item())
166165

167166
# Handle negative dim
168-
if dim < 0: dim += tsr.ndim
167+
while dim < 0: dim += tsr.ndim
168+
while i < 0: i += tsr.shape[dim]
169+
while j < 0: j += tsr.shape[dim]
169170

170-
# Create permuted index
171171
indices = torch.arange(tsr.shape[dim], device=tsr.device)
172-
# Swap
173-
# Check bounds? Torch index_select will check bounds or crash.
174-
# Support python style negative indexing for indices
175-
if i < 0: i += tsr.shape[dim]
176-
if j < 0: j += tsr.shape[dim]
177172

178173
val_i = indices[i].clone()
179174
indices[i] = indices[j]

0 commit comments

Comments
 (0)