Skip to content

Commit d8aae9a

Browse files
kaushikcfdinducer
authored andcommitted
Replace data attribute of DictOfNamedArrays.
1 parent e2a3111 commit d8aae9a

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

pytato/transform/einsum_distributive_law.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def map_dict_of_named_arrays(
377377
self, expr: DictOfNamedArrays, ctx: _EinsumDistributiveLawMapperContext | None
378378
) -> DictOfNamedArrays:
379379
return expr.replace_if_different(
380-
_data=constantdict(
380+
data=constantdict(
381381
{
382382
name: _verify_is_array(self.rec(subexpr, ctx))
383383
for name, subexpr in expr._data.items()

test/test_linalg.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,31 @@ def test_apply_einsum_distributive_law_3():
144144
assert y is apply_distributive_property_to_einsums(y, lambda x: DoNotDistribute())
145145

146146

147+
def test_apply_einsum_distributive_law_4():
148+
from pytato.transform.einsum_distributive_law import (
149+
DoDistribute,
150+
DoNotDistribute,
151+
EinsumDistributiveLawDescriptor,
152+
apply_distributive_property_to_einsums,
153+
)
154+
155+
def how_to_distribute(
156+
expr: pt.Einsum) -> EinsumDistributiveLawDescriptor:
157+
if pt.analysis.is_einsum_similar_to_subscript(
158+
expr, "ij,j->i"):
159+
return DoDistribute(ioperand=1)
160+
else:
161+
return DoNotDistribute()
162+
163+
x1 = pt.make_placeholder("x1", 4, np.float64)
164+
x2 = pt.make_placeholder("x2", 4, np.float64)
165+
A = pt.make_placeholder("A", (10, 4), np.float64)
166+
y = pt.make_dict_of_named_arrays({"y": A @ (x1+x2)})
167+
y_transformed = apply_distributive_property_to_einsums(y, how_to_distribute)
168+
169+
assert y_transformed == pt.make_dict_of_named_arrays({"y": A@x1 + A@x2})
170+
171+
147172
if __name__ == "__main__":
148173
if len(sys.argv) > 1:
149174
exec(sys.argv[1])

0 commit comments

Comments
 (0)