Skip to content

Commit 0d8cc9c

Browse files
committed
fix longstanding bug of tensor OP list
1 parent 4ac582c commit 0d8cc9c

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

more_math/Parser/UnifiedMathVisitor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,13 @@ def _bin_op(self, a, b, torch_op, scalar_op):
101101
# one of them is a list and one is tensor
102102
if self._is_tensor(a) and self._is_list(b):
103103
if(a.shape[0]==len(b)):
104-
c = torch.split(a,1)
105-
return torch.cat([self._bin_op(x, y, torch_op, scalar_op) for x,y in zip(a,c)],dim=0)
104+
A = torch.split(a,1)
105+
return torch.cat([self._bin_op(x, y, torch_op, scalar_op) for x,y in zip(A,b)],dim=0)
106106
return torch.cat([self._bin_op(a, x, torch_op, scalar_op) for x in b], dim=0)
107107
if self._is_list(a) and self._is_tensor(b):
108108
if(b.shape[0]==len(a)):
109-
c = torch.split(a,1)
110-
return torch.cat([self._bin_op(x, y, torch_op, scalar_op) for x,y in zip(c,b)],dim=0)
109+
B = torch.split(b,1)
110+
return torch.cat([self._bin_op(x, y, torch_op, scalar_op) for x,y in zip(a,B)],dim=0)
111111
return torch.cat([self._bin_op(x, b, torch_op, scalar_op) for x in a], dim=0)
112112

113113
if self._is_list(a) and not self._is_tensor(b):

more_math/helper_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ def as_tensor(value, shape):
2222
return value.contiguous()
2323
if isinstance(value, (float, int)):
2424
value = (value,)
25-
# If it's a scalar or list, broadcast to the reference shape provided.
26-
return torch.broadcast_to(torch.Tensor(value).to(dtype=torch.float32), shape).contiguous()
25+
return torch.broadcast_to(torch.Tensor(value).to(dtype=torch.float32), shape).contiguous()
26+
return torch.cat(value)
2727

2828

2929
def parse_expr(expr: str):

0 commit comments

Comments
 (0)