Skip to content

Commit 7fad03d

Browse files
committed
torch.full apparently not needed
1 parent 6756e43 commit 7fad03d

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

more_math/Parser/TensorEvalVisitor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,10 +255,10 @@ def visitStepFunc(self, ctx):
255255
# N-argument functions
256256
def visitSMinFunc(self, ctx):
257257
args = [self.visit(e) for e in ctx.expr()]
258-
return torch.full(self.shape, torch.min(torch.stack(args)), device=self.device)
258+
return torch.min(torch.stack(torch.broadcast_tensors(*args)))
259259
def visitSMaxFunc(self, ctx):
260-
args = [torch.reshape(self.visit(e), self.shape) for e in ctx.expr()]
261-
return torch.full(self.shape, torch.max(torch.stack(args)), device=self.device)
260+
args = [self.visit(e) for e in ctx.expr()]
261+
return torch.max(torch.stack(torch.broadcast_tensors(*args)))
262262

263263
def visitTMinFunc(self, ctx):
264264
return torch.minimum(self.visit(ctx.expr(0)),self.visit(ctx.expr(1)))

0 commit comments

Comments
 (0)