@@ -5506,15 +5506,26 @@ def _operands_to_keys(*operands, num_keys=1):
55065506
def _sort_jvp(primals, tangents, *, dimension, is_stable, num_keys):
  """JVP rule for the sort primitive.

  The primal outputs are obtained by sorting the operands together with an
  extra iota operand laid out along ``dimension``.  After the sort, that extra
  operand holds, for every output position, the index along ``dimension`` the
  element was taken from.  Each non-zero tangent is then rearranged with a
  single batched gather that uses those indices.

  Args:
    primals: the operands being sorted; the working shape is taken from
      ``primals[0]``.
    tangents: one tangent per primal; entries may be ``ad_util.Zero``, in
      which case they are passed through untouched.
    dimension: axis that is sorted.  It is used directly as a gather dimension
      number, so it is presumably already canonicalized to a non-negative
      value by the caller -- TODO confirm.
    is_stable: forwarded unchanged to ``sort_p.bind``.
    num_keys: forwarded unchanged to ``sort_p.bind``.

  Returns:
    A pair ``(primals_out, tangents_out)``: the sorted operands as a tuple
    (without the helper index operand) and the list of rearranged tangents.
  """
  shape = primals[0].shape
  rank = len(shape)

  # Sort an iota along `dimension` as one more trailing operand; the last
  # output of the bind is then the source index of every sorted element.
  sorted_primals_and_idx = sort_p.bind(
      *primals, broadcasted_iota(np.uint64, shape, dimension),
      dimension=dimension, is_stable=is_stable, num_keys=num_keys)

  # Every axis except the sorted one is a batching dimension of the gather:
  # elements only move along `dimension`.
  batch_dims = tuple(np.delete(np.arange(rank, dtype=np.int64), dimension))
  dnums = slicing.GatherDimensionNumbers(
      offset_dims=(),
      collapsed_slice_dims=(dimension,),
      start_index_map=(dimension,),
      operand_batching_dims=batch_dims,
      start_indices_batching_dims=batch_dims,
  )

  # Append a trailing axis of size 1 to act as the gather index-vector axis.
  idx = expand_dims(sorted_primals_and_idx[-1], (rank,))

  # The indices originate from an iota over `dimension`, hence the
  # PROMISE_IN_BOUNDS mode (assuming the sort only permutes that operand).
  gather_idx = partial(
      slicing.gather,
      start_indices=idx,
      dimension_numbers=dnums,
      slice_sizes=(1,) * rank,
      mode=slicing.GatherScatterMode.PROMISE_IN_BOUNDS,
  )

  # Symbolic zero tangents need no rearranging.
  tangents_out = [t if type(t) is ad_util.Zero else gather_idx(t)
                  for t in tangents]
  return tuple(sorted_primals_and_idx[:-1]), tangents_out
55195530
55205531def _sort_batch_rule (batched_args , batch_dims , * , dimension , is_stable , num_keys ):
0 commit comments