@@ -164,9 +164,20 @@ def make_node(self, x, y):
164164
165165 x_shape_dict = dict (zip (x .type .dims , x .type .shape ))
166166 y_shape_dict = dict (zip (y .type .dims , y .type .shape ))
167- shape_dict = {** x_shape_dict , ** y_shape_dict }
167+
168+ # Check for dimension size mismatches (concrete only)
169+ for dim in self .dims :
170+ x_shape = x_shape_dict .get (dim , None )
171+ y_shape = y_shape_dict .get (dim , None )
172+ if (
173+ isinstance (x_shape , int )
174+ and isinstance (y_shape , int )
175+ and x_shape != y_shape
176+ ):
177+ raise ValueError (f"Size of dim '{ dim } ' does not match" )
168178
169179 # Determine output dimensions
180+ shape_dict = {** x_shape_dict , ** y_shape_dict }
170181 out_dims = tuple (d for d in shape_dict if d not in self .dims )
171182
172183 # Determine output shape
@@ -231,17 +242,6 @@ def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None):
231242 if d not in union :
232243 raise ValueError (f"Dimension { d } not found in either input" )
233244
234- # Check for dimension size mismatches (concrete only)
235- for dim in intersection :
236- x_idx = x .type .dims .index (dim )
237- y_idx = y .type .dims .index (dim )
238- if (
239- isinstance (x .type .shape [x_idx ], int )
240- and isinstance (y .type .shape [y_idx ], int )
241- and x .type .shape [x_idx ] != y .type .shape [y_idx ]
242- ):
243- raise ValueError (f"Size of dim '{ dim } ' does not match" )
244-
245245 result = XDot (dims = tuple (dim_set ))(x , y )
246246
247247 return result
0 commit comments