@@ -2439,27 +2439,7 @@ class Join(COp):
24392439 """
24402440
24412441 check_input = False
2442- __props__ = ("view" ,)
2443-
2444- def __init__ (self , view = - 1 ):
2445- self .view = view
2446- if view != - 1 :
2447- # since the first input is always the axis, the tensors
2448- # start from index 1.
2449- self .view_map = {0 : [1 + view ]}
2450-
2451- def __str__ (self ):
2452- if self .view == - 1 :
2453- return self .__class__ .__name__
2454- else :
2455- classname = self .__class__ .__name__
2456- args = ", " .join (f"{ p } ={ getattr (self , p )!r} " for p in self .__props__ )
2457- return f"{ classname } {{{ args } }}"
2458-
2459- def __setstate__ (self , d ):
2460- self .__dict__ .update (d )
2461- if not hasattr (self , "view" ):
2462- self .view = - 1
2442+ __props__ = ()
24632443
24642444 def make_node (self , axis , * tensors ):
24652445 """
@@ -2476,74 +2456,62 @@ def make_node(self, axis, *tensors):
         if not tensors:
             raise ValueError("Cannot join an empty list of tensors")
 
+        axis = as_tensor_variable(axis)
+        if axis.type.dtype not in int_dtypes:
+            raise TypeError(f"Axis {axis} must be an integer type.")
+        if axis.type.ndim > 0:
+            raise TypeError(f"Axis {axis} must be 0-d.")
+
         tensors = [as_tensor_variable(x) for x in tensors]
-        out_dtype = ps.upcast(*[x.type.dtype for x in tensors])
 
-        if not builtins.all(targs.type.ndim for targs in tensors):
+        if not builtins.all(targs.type.ndim > 0 for targs in tensors):
             raise TypeError(
                 "Join cannot handle arguments of dimension 0."
-                " Use `stack` to join scalar values."
+                " Use `stack` to join scalar values and/or increase rank of scalars."
             )
 
         if len(tensors) == 1:
             out_shape = tensors[0].type.shape
         else:
-            # When the axis is fixed, a dimension should be
-            # broadcastable if at least one of the inputs is
-            # broadcastable on that dimension (see justification below),
-            # except for the axis dimension.
-            # Initialize bcastable all false, and then fill in some trues with
-            # the loops.
-
-            if not isinstance(axis, int):
-                try:
-                    axis = int(get_scalar_constant_value(axis))
-                except NotScalarConstantError:
-                    pass
-
             ndim = tensors[0].type.ndim
-            if isinstance(axis, int):
-                # Basically, broadcastable -> length 1, but the
-                # converse does not hold. So we permit e.g. T/F/T
-                # joins, and if they fail at runtime they fail, but if
-                # they don't then it means that the argument where
-                # that broadcastable flag was False had length 1 along
-                # this dimension, and therefore this dimension should
-                # be broadcastable for the output.
-
-                if axis < -ndim:
-                    raise IndexError(
-                        f"Axis value {axis} is out of range for the given input dimensions"
-                    )
-                if axis < 0:
-                    axis += ndim
-                if axis > ndim - 1:
-                    raise ValueError(
-                        f"Axis value {axis} is out of range for the given input dimensions"
-                    )
-                # NOTE: Constant negative axis can no longer be negative at this point.
-
-                in_shapes = [x.type.shape for x in tensors]
-                in_ndims = [len(s) for s in in_shapes]
-                if set(in_ndims) != {ndim}:
-                    raise TypeError(
-                        "Only tensors with the same number of dimensions can be joined."
-                        f" Input ndims were: {in_ndims}."
-                    )
+
+            if not builtins.all(x.ndim == ndim for x in tensors):
+                raise TypeError(
+                    "Only tensors with the same number of dimensions can be joined"
+                )
+            try:
+                # Note: this is dubious; if a user passed a constant we should
+                # propagate it to the inputs, not override it.
+                static_axis = int(get_scalar_constant_value(axis))
+            except NotScalarConstantError:
+                static_axis = None
+
+            if static_axis is None:
+                # When the axis isn't static we can't conclude anything about the output
+                # dimensions (except for degenerate zero-size arrays, which rewrites can remove).
+                # We could also raise an error when some dimension is pairwise inconsistent
+                # across all inputs, since the join would then be invalid for any axis.
+                # However, a dynamic axis is rare enough that it is not worth the trouble.
+                out_shape = [None] * ndim
+
+            else:  # We know the axis statically
+                static_axis = normalize_axis_index(static_axis, ndim)
+                static_shapes = [x.type.shape for x in tensors]
 
                 # Determine output shapes from a matrix of input shapes
-                in_shapes = np.array(in_shapes)
+                static_shapes = np.array(static_shapes)
                 out_shape = [None] * ndim
                 for d in range(ndim):
-                    ins = in_shapes[:, d]
-                    if d == axis:
-                        # Any unknown size along the axis means we can't sum
+                    ins = static_shapes[:, d]
+                    if d == static_axis:
+                        # Any unknown size along the axis means we can't infer it
                         if None in ins:
                             out_shape[d] = None
                         else:
                             out_shape[d] = sum(ins)
                     else:
-                        inset = set(in_shapes[:, d])
+                        inset = set(static_shapes[:, d])
                         # Other dims must match exactly,
                         # or if a mix of None and ? the output will be ?
                         # otherwise the input shapes are incompatible.
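The static shape rule implemented in this and the following hunk can be summarized in plain Python (a sketch only, with `None` standing for an unknown size):

    def join_static_shape(static_axis, *shapes):
        ndim = len(shapes[0])
        out = [None] * ndim
        for d in range(ndim):
            sizes = [s[d] for s in shapes]
            if d == static_axis:
                # Along the join axis sizes add up; any unknown size makes the result unknown.
                out[d] = None if None in sizes else sum(sizes)
            else:
                known = set(sizes) - {None}
                if len(known) > 1:
                    raise ValueError(f"incompatible sizes along dimension {d}: {sizes}")
                out[d] = known.pop() if known else None
        return tuple(out)

    assert join_static_shape(0, (2, 3), (None, 3)) == (None, 3)
    assert join_static_shape(1, (2, 3), (2, 4)) == (2, 7)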
@@ -2553,54 +2521,27 @@ def make_node(self, axis, *tensors):
                             (out_shape[d],) = inset - {None}
                         else:
                             raise ValueError(
-                                f"all input array dimensions other than the specified `axis` ({axis})"
+                                f"all input array dimensions other than the specified `axis` ({static_axis})"
                                 " must match exactly, or be unknown (None),"
                                 f" but along dimension {d}, the inputs shapes are incompatible: {ins}"
                             )
-            else:
-                # When the axis may vary, no dimension can be guaranteed to be
-                # broadcastable.
-                out_shape = [None] * tensors[0].type.ndim
-
-        if not builtins.all(x.ndim == len(out_shape) for x in tensors):
-            raise TypeError(
-                "Only tensors with the same number of dimensions can be joined"
-            )
-
-        inputs = [as_tensor_variable(axis), *tensors]
-
-        if inputs[0].type.dtype not in int_dtypes:
-            raise TypeError(f"Axis value {inputs[0]} must be an integer type")
 
+        inputs = [axis, *tensors]
+        out_dtype = ps.upcast(*[x.type.dtype for x in tensors])
         return Apply(self, inputs, [tensor(dtype=out_dtype, shape=out_shape)])
 
-    def perform(self, node, axis_and_tensors, out_):
-        (out,) = out_
-        view = self.view
-        axis, tens = axis_and_tensors[0], axis_and_tensors[1:]
-        # we check these tensors for being empty.
-        if (view != -1) and all(
-            tensor.shape[axis] == 0 for tensor in tens[0:view] + tens[view + 1 :]
-        ):
-            out[0] = tens[view]
-
-        else:
-            ndim = tens[0].ndim
-            if axis < -ndim:
-                raise IndexError(
-                    f"Join axis {int(axis)} out of bounds [0, {int(ndim)})"
-                )
-
-            out[0] = np.asarray(
-                np.concatenate(tens, axis=axis), dtype=node.outputs[0].type.dtype
-            )
+    def perform(self, node, inputs, output_storage):
+        axis, *arrays = inputs
+        output_storage[0][0] = np.concatenate(
+            arrays, axis=axis, dtype=node.outputs[0].type.dtype
+        )
 
     def c_code_cache_version(self):
         return (5,)
 
     def c_code(self, node, name, inputs, outputs, sub):
         axis, tens = inputs[0], inputs[1:]
-        view = self.view
+        view = -1
         non_empty_tensor = tens[view]
         input_1 = tens[0]
         l = len(tens)
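The rewritten `perform` defers entirely to NumPy: `np.concatenate` accepts a `dtype` argument (NumPy >= 1.20), so concatenation and the cast to the node's output dtype happen in a single call. A small standalone check:

    import numpy as np

    a = np.array([1, 2], dtype="int8")
    b = np.array([3.5, 4.5], dtype="float64")
    # Concatenate and upcast in one call, as the new perform does.
    out = np.concatenate([a, b], axis=0, dtype="float64")
    assert out.dtype == np.dtype("float64") and out.shape == (4,)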
@@ -2656,22 +2597,21 @@ def R_op(self, inputs, eval_points):
             return [None]
         return self.make_node(inputs[0], *eval_points[1:]).outputs
 
-    def grad(self, axis_and_tensors, grads):
+    def L_op(self, inputs, outputs, grads):
         """The gradient wrt a join op is a `Split`, used to partition
         the gradient along the `axis` which was used for joining.
         """
-        (gz,) = grads
-        axis, tens = axis_and_tensors[0], axis_and_tensors[1:]
+        [gz] = grads
+        [out] = outputs
+        axis, *tensors = inputs
 
         rval = [grad_undefined(self, 0, axis)]
-
-        dtypes = [as_tensor_variable(x).type.dtype for x in tens]
-        out_dtype = ps.upcast(*dtypes)
+        out_dtype = out.type.dtype
 
         if "float" in out_dtype or "complex" in out_dtype:
             # assume that this is differentiable
-            split = Split(len(tens))
-            split_gz = split(gz, axis, stack([shape(x)[axis] for x in tens]))
+            split_sizes = stack([shape(x)[axis] for x in tensors])
+            split_gz = split(gz, split_sizes, n_splits=len(tensors), axis=axis)
             # If there is only one split, it might not be in a list.
             if not isinstance(split_gz, list):
                 split_gz = [split_gz]
@@ -2684,13 +2624,12 @@ def grad(self, axis_and_tensors, grads):
                 else specify_broadcastable(
                     g, *(ax for (ax, s) in enumerate(t.type.shape) if s == 1)
                 )
-                for t, g in zip(tens, split_gz, strict=True)
+                for t, g in zip(tensors, split_gz, strict=True)
             ]
             rval = rval + split_gz
         else:
-            # the output has integer type, so the gradient through it
-            # is 0
-            rval = rval + [t.zeros_like(dtype=config.floatX) for t in tens]
+            # the output has integer type, so the gradient through it is 0
+            rval = rval + [t.zeros_like(dtype=config.floatX) for t in tensors]
 
         return rval
 
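The `L_op` above encodes the rule that the gradient of a join is a split. A hedged end-to-end sketch through the public API (assumes the usual `pytensor` / `pytensor.tensor as pt` imports are available):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    y = pt.vector("y")
    cost = pt.join(0, x, y).sum()
    # The upstream gradient is partitioned back into pieces shaped like x and y.
    gx, gy = pytensor.grad(cost, [x, y])
    f = pytensor.function([x, y], [gx, gy])
    gx_val, gy_val = f(np.zeros(3, dtype=x.dtype), np.zeros(2, dtype=y.dtype))
    assert gx_val.shape == (3,) and gy_val.shape == (2,)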
@@ -2710,7 +2649,8 @@ def infer_shape(self, fgraph, node, ishapes):
         # An axis < -n_dim or >= ndim would be invalid, but this is
         # not checked here. A `CheckAndRaise` `Op` would be a way of
         # addressing that, but it may disrupt optimizations.
-        join_dim = switch(ge(node.inputs[0], 0), node.inputs[0], node.inputs[0] + n_dim)
+        axis = node.inputs[0]
+        join_dim = switch(ge(axis, 0), axis, axis + n_dim)
         out_shapes = []
         for dim in range(n_dim):
             # we have to deal with 2 possible cases in here :
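The `switch(ge(axis, 0), axis, axis + n_dim)` expression is the symbolic counterpart of the usual negative-axis normalization; numerically it amounts to (illustration only):

    def normalize_axis(axis, n_dim):
        # Map a possibly negative axis into the range [0, n_dim).
        return axis if axis >= 0 else axis + n_dim

    assert normalize_axis(-1, 3) == 2
    assert normalize_axis(1, 3) == 1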
@@ -2733,7 +2673,7 @@ def infer_shape(self, fgraph, node, ishapes):
         return [tuple(out_shapes)]
 
 
-join_ = Join()
+_join = Join()
 pprint.assign(Join, printing.FunctionPrinter(["join"]))
 
 
@@ -2776,7 +2716,7 @@ def join(axis, *tensors_list):
     if len(tensors_list) == 1:
         return tensors_list[0]
     else:
-        return join_(axis, *tensors_list)
+        return _join(axis, *tensors_list)
 
 
 @_vectorize_node.register(Join)
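For reference, the public `join` helper (now dispatching to the renamed `_join` instance) behaves like `np.concatenate`, and the stricter `make_node` gives it a static output shape whenever the inputs' shapes and the axis are known. A hedged usage sketch, assuming the conventional `import pytensor.tensor as pt`:

    import pytensor.tensor as pt

    x = pt.zeros((2, 3))
    y = pt.zeros((4, 3))
    z = pt.join(0, x, y)
    # Sizes add along the join axis; the remaining dimensions must match.
    print(z.type.shape)  # expected: (6, 3)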