@@ -578,7 +578,7 @@ def _bcoo_transpose_transpose(ct, data, indices, *, permutation: Sequence[int],
578578 raise ValueError ("Cannot transpose with respect to sparse indices" )
579579 assert data_ct .dtype == data .aval .dtype
580580 ct_spinfo = SparseInfo (tuple (spinfo .shape [p ] for p in permutation ))
581- rev_permutation = list (np .argsort (permutation ))
581+ rev_permutation = list (map ( int , np .argsort (permutation ) ))
582582 # TODO(jakevdp) avoid dummy indices?
583583 dummy_indices = jnp .zeros ([1 for i in range (indices .ndim - 2 )] + list (indices .shape [- 2 :]), dtype = int )
584584 data_trans , _ = _bcoo_transpose (data_ct , dummy_indices , permutation = rev_permutation , spinfo = ct_spinfo )
@@ -865,7 +865,7 @@ def _bcoo_dot_general_transpose(ct, lhs_data, lhs_indices, rhs, *, dimension_num
865865 dims : DotDimensionNumbers = ((ans_rhs , rhs_kept ), (ans_batch , rhs_batch ))
866866 lhs_contract_sorted_by_rhs = list (np .take (lhs_contract , np .argsort (rhs_contract )))
867867 permutation = list (lhs_batch ) + lhs_kept + lhs_contract_sorted_by_rhs
868- out_axes = list (np .argsort (permutation ))
868+ out_axes = list (map ( int , np .argsort (permutation ) ))
869869
870870 # Determine whether efficient approach is possible:
871871 placeholder_data = jnp .empty ((lhs_indices .ndim - 2 ) * (1 ,) + (lhs_indices .shape [- 2 ],))
0 commit comments