File tree Expand file tree Collapse file tree 1 file changed +21
-0
lines changed
Expand file tree Collapse file tree 1 file changed +21
-0
lines changed Original file line number Diff line number Diff line change @@ -602,6 +602,27 @@ def _get_fake_numpy_namespace(self):
602602 def transform_dag (self , dag ):
603603 from pytato .array import Einsum
604604
605+ # {{{ face_mass: materialize einsum args
606+
607+ def materialize_face_mass_vec (expr ):
608+ if isinstance (expr , pt .Einsum ):
609+ my_tag , = expr .tags_of_type (pt .tags .EinsumInfo )
610+ if my_tag .spec == "ifj,fej,fej->ei" :
611+ mat , jac , vec = expr .args
612+ return pt .einsum ("ifj,fej,fej->ei" ,
613+ mat ,
614+ jac ,
615+ vec .tagged (pt .tags
616+ .ImplementAs (pt .tags .ImplStored ())))
617+ else :
618+ return expr
619+ else :
620+ return expr
621+
622+ dag = pt .transform .map_and_copy (dag , materialize_face_mass_vec )
623+
624+ # }}}
625+
605626 # {{{ materialize
606627
607628 nusers = pt .analysis .get_nusers (dag )
You can’t perform that action at this time.
0 commit comments