@@ -2439,27 +2439,7 @@ class Join(COp):
24392439 """
24402440
24412441 check_input = False
2442- __props__ = ("view" ,)
2443-
2444- def __init__ (self , view = - 1 ):
2445- self .view = view
2446- if view != - 1 :
2447- # since the first input is always the axis, the tensors
2448- # start from index 1.
2449- self .view_map = {0 : [1 + view ]}
2450-
2451- def __str__ (self ):
2452- if self .view == - 1 :
2453- return self .__class__ .__name__
2454- else :
2455- classname = self .__class__ .__name__
2456- args = ", " .join (f"{ p } ={ getattr (self , p )!r} " for p in self .__props__ )
2457- return f"{ classname } {{{ args } }}"
2458-
2459- def __setstate__ (self , d ):
2460- self .__dict__ .update (d )
2461- if not hasattr (self , "view" ):
2462- self .view = - 1
2442+ __props__ = ()
24632443
24642444 def make_node (self , axis , * tensors ):
24652445 """
@@ -2476,74 +2456,62 @@ def make_node(self, axis, *tensors):
         if not tensors:
             raise ValueError("Cannot join an empty list of tensors")
 
+        axis = as_tensor_variable(axis)
+        if axis.type.dtype not in int_dtypes:
+            raise TypeError(f"Axis {axis} must be an integer type.")
+        if axis.type.ndim > 0:
+            raise TypeError(f"Axis {axis} must be 0-d.")
+
         tensors = [as_tensor_variable(x) for x in tensors]
-        out_dtype = ps.upcast(*[x.type.dtype for x in tensors])
 
-        if not builtins.all(targs.type.ndim for targs in tensors):
+        if not builtins.all(targs.type.ndim > 0 for targs in tensors):
             raise TypeError(
                 "Join cannot handle arguments of dimension 0."
-                " Use `stack` to join scalar values."
+                " Use `stack` to join scalar values and/or increase rank of scalars."
             )
 
         if len(tensors) == 1:
             out_shape = tensors[0].type.shape
         else:
-            # When the axis is fixed, a dimension should be
-            # broadcastable if at least one of the inputs is
-            # broadcastable on that dimension (see justification below),
-            # except for the axis dimension.
-            # Initialize bcastable all false, and then fill in some trues with
-            # the loops.
-
-            if not isinstance(axis, int):
-                try:
-                    axis = int(get_scalar_constant_value(axis))
-                except NotScalarConstantError:
-                    pass
-
             ndim = tensors[0].type.ndim
-            if isinstance(axis, int):
-                # Basically, broadcastable -> length 1, but the
-                # converse does not hold. So we permit e.g. T/F/T
-                # joins, and if they fail at runtime they fail, but if
-                # they don't then it means that the argument where
-                # that broadcastable flag was False had length 1 along
-                # this dimension, and therefore this dimension should
-                # be broadcastable for the output.
-
-                if axis < -ndim:
-                    raise IndexError(
-                        f"Axis value {axis} is out of range for the given input dimensions"
-                    )
-                if axis < 0:
-                    axis += ndim
-                if axis > ndim - 1:
-                    raise ValueError(
-                        f"Axis value {axis} is out of range for the given input dimensions"
-                    )
-                # NOTE: Constant negative axis can no longer be negative at this point.
-
-                in_shapes = [x.type.shape for x in tensors]
-                in_ndims = [len(s) for s in in_shapes]
-                if set(in_ndims) != {ndim}:
-                    raise TypeError(
-                        "Only tensors with the same number of dimensions can be joined."
-                        f" Input ndims were: {in_ndims}."
-                    )
+
+            if not builtins.all(x.ndim == ndim for x in tensors):
+                raise TypeError(
+                    "Only tensors with the same number of dimensions can be joined"
+                )
+
+            try:
+                # Note: this is dubious; if a user passed a constant we should
+                # propagate it to the inputs, not override it.
+                static_axis = int(get_scalar_constant_value(axis))
+            except NotScalarConstantError:
+                static_axis = None
+
+            if static_axis is None:
+                # When the axis isn't static, we can't conclude anything about the
+                # output dimensions (unless some inputs are degenerate zero-sized
+                # arrays, which rewrites can remove).
+                # We could also raise an error when dimensions are pairwise
+                # inconsistent across all the inputs, since the join would then be
+                # invalid no matter the axis.
+                # However, a dynamic axis is so rare that it is not worth the trouble.
+                out_shape = [None] * ndim
+
+            else:  # We know the axis statically
+                static_axis = normalize_axis_index(static_axis, ndim)
+                static_shapes = [x.type.shape for x in tensors]
 
                 # Determine output shapes from a matrix of input shapes
-                in_shapes = np.array(in_shapes)
+                static_shapes = np.array(static_shapes)
                 out_shape = [None] * ndim
                 for d in range(ndim):
-                    ins = in_shapes[:, d]
-                    if d == axis:
-                        # Any unknown size along the axis means we can't sum
+                    ins = static_shapes[:, d]
+                    if d == static_axis:
+                        # Any unknown size along the axis means we can't infer it
                         if None in ins:
                             out_shape[d] = None
                         else:
                             out_shape[d] = sum(ins)
                     else:
-                        inset = set(in_shapes[:, d])
+                        inset = set(static_shapes[:, d])
                         # Other dims must match exactly, or be a mix of None and
                         # a single static size (the output then takes that size);
                         # otherwise the input shapes are incompatible.
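The static-axis branch implements a simple shape algebra: along the join axis, known sizes add up (any `None` makes the result `None`); along every other axis, sizes must agree, with `None` deferring to whatever static size is present. A standalone sketch of that rule, where `infer_join_shape` is a hypothetical helper written only for illustration:

```python
def infer_join_shape(shapes, axis):
    # shapes: one tuple per input, with None marking an unknown static size
    ndim = len(shapes[0])
    if not -ndim <= axis < ndim:
        raise ValueError("axis out of bounds")
    axis %= ndim  # same effect as normalize_axis_index for in-range axes
    out = [None] * ndim
    for d in range(ndim):
        sizes = [s[d] for s in shapes]
        if d == axis:
            # Any unknown size along the join axis makes the sum unknown
            out[d] = None if None in sizes else sum(sizes)
        else:
            known = set(sizes) - {None}
            if len(known) > 1:
                raise ValueError(f"incompatible sizes along dimension {d}: {sizes}")
            out[d] = known.pop() if known else None
    return tuple(out)

assert infer_join_shape([(2, 3), (4, 3)], axis=0) == (6, 3)
assert infer_join_shape([(2, None), (4, 3)], axis=0) == (6, 3)
assert infer_join_shape([(2, 3), (None, 3)], axis=0) == (None, 3)
```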
@@ -2553,100 +2521,71 @@ def make_node(self, axis, *tensors):
                             (out_shape[d],) = inset - {None}
                         else:
                             raise ValueError(
-                                f"all input array dimensions other than the specified `axis` ({axis})"
+                                f"all input array dimensions other than the specified `axis` ({static_axis})"
                                 " must match exactly, or be unknown (None),"
                                 f" but along dimension {d}, the input shapes are incompatible: {ins}"
                             )
-        else:
-            # When the axis may vary, no dimension can be guaranteed to be
-            # broadcastable.
-            out_shape = [None] * tensors[0].type.ndim
-
-        if not builtins.all(x.ndim == len(out_shape) for x in tensors):
-            raise TypeError(
-                "Only tensors with the same number of dimensions can be joined"
-            )
-
-        inputs = [as_tensor_variable(axis), *tensors]
-
-        if inputs[0].type.dtype not in int_dtypes:
-            raise TypeError(f"Axis value {inputs[0]} must be an integer type")
 
+        inputs = [axis, *tensors]
+        out_dtype = ps.upcast(*[x.type.dtype for x in tensors])
         return Apply(self, inputs, [tensor(dtype=out_dtype, shape=out_shape)])
 
-    def perform(self, node, axis_and_tensors, out_):
-        (out,) = out_
-        view = self.view
-        axis, tens = axis_and_tensors[0], axis_and_tensors[1:]
-        # we check these tensors for being empty.
-        if (view != -1) and all(
-            tensor.shape[axis] == 0 for tensor in tens[0:view] + tens[view + 1 :]
-        ):
-            out[0] = tens[view]
-
-        else:
-            ndim = tens[0].ndim
-            if axis < -ndim:
-                raise IndexError(
-                    f"Join axis {int(axis)} out of bounds [0, {int(ndim)})"
-                )
-
-            out[0] = np.asarray(
-                np.concatenate(tens, axis=axis), dtype=node.outputs[0].type.dtype
-            )
+    def perform(self, node, inputs, output_storage):
+        axis, *arrays = inputs
+        output_storage[0][0] = np.concatenate(
+            arrays, axis=axis, dtype=node.outputs[0].type.dtype
+        )
 
     def c_code_cache_version(self):
-        return (5,)
+        return (6,)
 
     def c_code(self, node, name, inputs, outputs, sub):
-        axis, tens = inputs[0], inputs[1:]
-        view = self.view
-        non_empty_tensor = tens[view]
-        input_1 = tens[0]
-        l = len(tens)
-        (out,) = outputs
+        axis, *arrays = inputs
+        [out] = outputs
+        n = len(arrays)
+        ndim = node.outputs[0].type.ndim
         fail = sub["fail"]
-        adtype = node.inputs[0].type.dtype_specs()[1]
 
-        copy_to_list = (
-            f"""Py_INCREF({inp}); PyList_SetItem(list, {i}, (PyObject*){inp});"""
-            for i, inp in enumerate(tens)
-        )
+        # Most of the time the axis is constant; inline it.
+        # This is safe to do because the hash of the c_code includes the
+        # constant signature.
+        if isinstance(node.inputs[0], Constant):
+            static_axis = int(node.inputs[0].data)
+            static_axis = normalize_axis_index(static_axis, ndim)
+            axis_def = f"{static_axis};"
+            axis_check = ""
+        else:
+            axis_dtype = node.inputs[0].type.dtype_specs()[1]
+            axis_def = f"(({axis_dtype} *)PyArray_DATA({axis}))[0];"
+            axis_check = f"""
+            if (axis < 0){{
+                axis = {ndim} + axis;
+            }}
+            if (axis >= {ndim} || axis < 0) {{
+                PyErr_SetString(PyExc_ValueError, "Join axis is out of bounds");
+                {fail}
+            }}
+            """
 
-        copy_inputs_to_list = "\n".join(copy_to_list)
-        n = len(tens)
+        copy_arrays_to_tuple = "\n".join(
+            (
+                f"""Py_INCREF({array}); PyTuple_SetItem(arrays_tuple, {i}, (PyObject*){array});"""
+                for i, array in enumerate(arrays)
+            )
+        )
 
         code = f"""
-        int axis = (({adtype} *)PyArray_DATA({axis}))[0];
-        PyObject* list = PyList_New({l});
-        {copy_inputs_to_list}
-        int tensors_lens_sum;
-        if({view} != -1) {{
-            tensors_lens_sum = 0;
-
-            for(int i=0; i < {n}; i++){{
-                tensors_lens_sum += PyArray_DIM((PyArrayObject *)(PyList_GetItem(list, i)), axis);
-            }}
-            tensors_lens_sum -= PyArray_DIM({non_empty_tensor}, axis);
-        }}
-        if({view} != -1 && tensors_lens_sum == 0) {{
-            Py_XDECREF({out});
-            Py_INCREF({non_empty_tensor});
-            {out} = {non_empty_tensor};
-        }}else{{
-            //PyObject* PyArray_Concatenate(PyObject* obj, int axis)
-            int ndim = PyArray_NDIM({input_1});
-            if( axis < -ndim ){{
-                PyErr_Format(PyExc_IndexError,
-                             "Join axis %d out of bounds [0, %d)", axis, ndim);
-                {fail}
-            }}
-            Py_XDECREF({out});
-            {out} = (PyArrayObject *)PyArray_Concatenate(list, axis);
-            Py_DECREF(list);
-            if(!{out}){{
-                {fail}
+        int axis = {axis_def}
+        PyArrayObject* arrays[{n}] = {{{','.join(arrays)}}};
+
+        {axis_check}
+
+        Py_XDECREF({out});
+        PyObject* arrays_tuple = PyTuple_New({n});
+        {copy_arrays_to_tuple}
+        {out} = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
+        Py_DECREF(arrays_tuple);
+        if(!{out}){{
+            {fail}
         }}
         """
         return code
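The rewritten `perform` leans on the `dtype` argument that NumPy added to `np.concatenate` in 1.20, which casts during the copy instead of requiring a second `np.asarray` pass. A NumPy-only illustration of the one-pass cast the Op now relies on:

```python
import numpy as np

a = np.array([1, 2], dtype="int32")
b = np.array([3, 4], dtype="int64")
# Concatenate and cast to the already-computed output dtype in one pass
out = np.concatenate([a, b], axis=0, dtype="int64")
assert out.dtype == np.dtype("int64") and out.tolist() == [1, 2, 3, 4]
```

On the C side, inlining a constant axis is sound because changing the constant changes the generated source itself, and the compiled-module cache is keyed on that source together with the bumped `c_code_cache_version`, so a stale binary cannot be reused for a different axis.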
@@ -2656,22 +2595,21 @@ def R_op(self, inputs, eval_points):
             return [None]
         return self.make_node(inputs[0], *eval_points[1:]).outputs
 
-    def grad(self, axis_and_tensors, grads):
+    def L_op(self, inputs, outputs, grads):
         """The gradient wrt a join op is a `Split`, used to partition
         the gradient along the `axis` which was used for joining.
         """
-        (gz,) = grads
-        axis, tens = axis_and_tensors[0], axis_and_tensors[1:]
+        [gz] = grads
+        [out] = outputs
+        axis, *tensors = inputs
 
         rval = [grad_undefined(self, 0, axis)]
-
-        dtypes = [as_tensor_variable(x).type.dtype for x in tens]
-        out_dtype = ps.upcast(*dtypes)
+        out_dtype = out.type.dtype
 
         if "float" in out_dtype or "complex" in out_dtype:
             # assume that this is differentiable
-            split = Split(len(tens))
-            split_gz = split(gz, axis, stack([shape(x)[axis] for x in tens]))
+            split_sizes = stack([shape(x)[axis] for x in tensors])
+            split_gz = split(gz, split_sizes, n_splits=len(tensors), axis=axis)
             # If there is only one split, it might not be in a list.
             if not isinstance(split_gz, list):
                 split_gz = [split_gz]
@@ -2684,13 +2622,12 @@ def grad(self, axis_and_tensors, grads):
                 else specify_broadcastable(
                     g, *(ax for (ax, s) in enumerate(t.type.shape) if s == 1)
                 )
-                for t, g in zip(tens, split_gz, strict=True)
+                for t, g in zip(tensors, split_gz, strict=True)
             ]
             rval = rval + split_gz
         else:
-            # the output has integer type, so the gradient through it
-            # is 0
-            rval = rval + [t.zeros_like(dtype=config.floatX) for t in tens]
+            # the output has integer type, so the gradient through it is 0
+            rval = rval + [t.zeros_like(dtype=config.floatX) for t in tensors]
 
         return rval
 
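Switching from `grad` to `L_op` gives the method access to the node's outputs, so the output dtype can be read directly instead of being recomputed with `ps.upcast`. The gradient itself is unchanged in spirit: the output gradient is split back into pieces matching each input's extent along the join axis. A small end-to-end check, assuming the standard public API (`pytensor.tensor.join`, `pytensor.grad`):

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.vector("y")
loss = pt.join(0, x, y).sum()
gx, gy = pytensor.grad(loss, [x, y])
f = pytensor.function([x, y], [gx, gy])
# Each input receives the slice of the output gradient it produced
print(f([1.0, 2.0], [3.0]))  # -> [array([1., 1.]), array([1.])]
```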
@@ -2710,7 +2647,8 @@ def infer_shape(self, fgraph, node, ishapes):
         # An axis < -n_dim or >= n_dim would be invalid, but this is
         # not checked here. A `CheckAndRaise` `Op` would be a way of
         # addressing that, but it may disrupt optimizations.
-        join_dim = switch(ge(node.inputs[0], 0), node.inputs[0], node.inputs[0] + n_dim)
+        axis = node.inputs[0]
+        join_dim = switch(ge(axis, 0), axis, axis + n_dim)
         out_shapes = []
         for dim in range(n_dim):
             # we have to deal with 2 possible cases here:
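Because `infer_shape` cannot assume the axis is static, the normalization happens symbolically: `switch(ge(axis, 0), axis, axis + n_dim)` is the graph-level equivalent of the usual Python rule for negative axes, sketched here:

```python
def normalize_axis(axis: int, ndim: int) -> int:
    # Plain-Python counterpart of switch(ge(axis, 0), axis, axis + ndim);
    # like the symbolic version, it does not validate the range.
    return axis if axis >= 0 else axis + ndim

assert normalize_axis(1, 3) == 1
assert normalize_axis(-1, 3) == 2
```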
@@ -2733,7 +2671,7 @@ def infer_shape(self, fgraph, node, ishapes):
27332671 return [tuple (out_shapes )]
27342672
27352673
2736- join_ = Join ()
2674+ _join = Join ()
27372675pprint .assign (Join , printing .FunctionPrinter (["join" ]))
27382676
27392677
@@ -2776,7 +2714,7 @@ def join(axis, *tensors_list):
     if len(tensors_list) == 1:
         return tensors_list[0]
     else:
-        return join_(axis, *tensors_list)
+        return _join(axis, *tensors_list)
 
 
 @_vectorize_node.register(Join)
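Only the module-level alias changes here (`join_` becomes the more conventional `_join`); the public `join` helper still short-circuits the single-tensor case, as a quick sketch against the public API shows:

```python
import pytensor.tensor as pt

x = pt.matrix("x")
y = pt.matrix("y")

assert pt.join(1, x) is x  # one tensor: returned as-is, no Join node
z = pt.join(1, x, y)       # two tensors: a Join apply node along axis 1
```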