@@ -44,7 +44,6 @@ limitations under the License.
 #include "xla/python/types.h"
 #include "xla/shape_util.h"
 #include "xla/xla_data.pb.h"
-#include "xla/util.h"
 
 namespace nb = nanobind;
 
@@ -82,7 +81,8 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream,
     auto arg = args.get<xla::ffi::AnyBuffer>(i);
     auto ptype = static_cast<xla::PrimitiveType>(arg->element_type());
     // TODO(b/395428868): Remove this check once we support subbyte types.
-    if (ptype == xla::S1 || ptype == xla::U1) {
+    if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 ||
+        ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) {
       return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented,
                              absl::StrFormat("Unsupported primitive type: %s",
                                              PrimitiveType_Name(ptype)));
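
Note: the expanded guard enumerates each sub-byte integer type by hand. An equivalent width-based check is sketched below; it assumes xla::primitive_util::BitWidth (already used elsewhere in this file) reports 1/2/4 bits for these types, as the packing code removed below relied on, and unlike the explicit list it would also reject any other sub-byte types XLA defines:

// Sketch: reject every sub-byte element type with one width check instead of
// enumerating S1/S2/S4/U1/U2/U4 explicitly.
if (xla::primitive_util::BitWidth(ptype) < 8) {
  return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented,
                         absl::StrFormat("Unsupported primitive type: %s",
                                         PrimitiveType_Name(ptype)));
}
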
@@ -112,30 +112,16 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream,
       PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr());
       continue;
     }
+    nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept {
+      delete[] static_cast<char*>(ptr);
+    });
     auto maybe_dtype = PrimitiveTypeToNbDtype(ptype);
     if (!maybe_dtype.ok()) {
       return xla::ffi::Error::Internal(maybe_dtype.status().ToString());
     }
     auto dtype = maybe_dtype.value();
     auto dims = absl::Span<const int64_t>(arg->dimensions().begin(),
                                           arg->dimensions().size());
-    // TODO(b/402422886): Remove this once we form Jax arrays directly instead
-    // of packing/unpacking to/from numpy arrays.
-    // We pass in data using default numpy layout i.e., std::nullopt.
-    size_t bits_per_element = xla::primitive_util::BitWidth(ptype);
-    if (bits_per_element == 2 || bits_per_element == 4) {
-      // NOTE(dsuo): FFI argument and return buffers are sized assuming
-      // minimum 1-byte element sizes, even if the data itself is packed.
-      size_t packed_size = arg->size_bytes() * bits_per_element / 8;
-      auto buffer = xla::UnpackIntN(
-          bits_per_element, static_cast<const char*>(host_input_buffers[i]),
-          packed_size);
-      delete[] static_cast<char*>(host_input_buffers[i]);
-      host_input_buffers[i] = buffer.release();
-    }
-    nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept {
-      delete[] static_cast<char*>(ptr);
-    });
     auto array = xla::nb_numpy_ndarray(dtype, dims, std::nullopt,
                                        host_input_buffers[i], base);
     array.attr("flags").attr("writeable") = nb::bool_(false);
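
Note: the relocated nb::capsule is the standard nanobind idiom for handing buffer ownership to Python: the numpy array keeps the capsule alive, and the capsule's destructor frees the host staging buffer when the array is collected. A minimal standalone sketch of the same pattern, following the nanobind ndarray documentation (module and function names are illustrative, and plain nb::ndarray stands in for XLA's nb_numpy_ndarray wrapper):

#include <cstring>

#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>

namespace nb = nanobind;

// Returns a float32 numpy array backed by a heap buffer; the capsule owns the
// buffer and delete[]s it once the last reference to the array goes away.
nb::ndarray<nb::numpy, float> MakeZeros(size_t n) {
  float* data = new float[n];
  std::memset(data, 0, n * sizeof(float));
  nb::capsule owner(data, [](void* p) noexcept {
    delete[] static_cast<float*>(p);
  });
  return nb::ndarray<nb::numpy, float>(data, {n}, owner);
}

NB_MODULE(example_ext, m) { m.def("make_zeros", &MakeZeros); }
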
@@ -160,7 +146,8 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream,
     auto ret = rets.get<xla::ffi::AnyBuffer>(i).value();
     auto ptype = static_cast<xla::PrimitiveType>(ret->element_type());
     // TODO(b/395428868): Remove this check once we support subbyte types.
-    if (ptype == xla::S1 || ptype == xla::U1) {
+    if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 ||
+        ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) {
       return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented,
                              absl::StrFormat("Unsupported primitive type: %s",
                                              PrimitiveType_Name(ptype)));
@@ -181,45 +168,32 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream,
     }
     auto expected_shape = maybe_expected_shape.value();
     auto expected_strides = xla::ByteStridesForShape(expected_shape);
-
-    const void* data = array.data();
-    size_t size_bytes = array.size() * array.itemsize();
-    if (strides != expected_strides) {
-      xla::TransposePlan::Options options;
-      options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype);
-      options.dims = absl::Span<int64_t const>(
-          reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
-      absl::InlinedVector<int64_t, 4> reversed_layout;
-      reversed_layout.resize(expected_shape.dimensions().size());
-      absl::c_reverse_copy(expected_shape.layout().minor_to_major(),
-                           reversed_layout.begin());
-      options.permutation = reversed_layout;
-      options.input_layout = xla::TransposePlan::Striding{strides};
-      auto maybe_plan = transpose_cache->cache.GetOrCreate(options);
-      if (!maybe_plan.ok()) {
-        return xla::ffi::Error::Internal(maybe_plan.status().ToString());
-      }
-      auto plan = maybe_plan.value();
-      void* temp = new char[size_bytes];
-      temp_buffers.push_back(temp);
-      plan->Execute(data, temp);
-      data = temp;
+    if (strides == expected_strides) {
+      auto gpu_res =
+          gpuMemcpyAsync(ret->untyped_data(), array.data(), ret->size_bytes(),
+                         gpuMemcpyHostToDevice, stream);
+      CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync";
+      continue;
     }
-
-    // TODO(b/402422886): Remove this once we form Jax arrays directly instead
-    // of packing/unpacking to/from numpy arrays.
-    std::unique_ptr<char[]> buffer;
-    size_t bits_per_element = xla::primitive_util::BitWidth(ptype);
-    if (bits_per_element == 2 || bits_per_element == 4) {
-      // NOTE(dsuo): FFI arguments and return buffers are sized assuming
-      // minimum 1-byte element sizes, even if the data itself is packed.
-      buffer = xla::PackIntN(bits_per_element, static_cast<const char*>(data),
-                             size_bytes);
-      data = buffer.get();
-      size_bytes = (size_bytes * bits_per_element) / 8;
+    void* temp = new char[ret->size_bytes()];
+    temp_buffers.push_back(temp);
+    xla::TransposePlan::Options options;
+    options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype);
+    options.dims = absl::Span<int64_t const>(
+        reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
+    absl::InlinedVector<int64_t, 4> reversed_layout;
+    reversed_layout.resize(expected_shape.dimensions().size());
+    absl::c_reverse_copy(expected_shape.layout().minor_to_major(),
+                         reversed_layout.begin());
+    options.permutation = reversed_layout;
+    options.input_layout = xla::TransposePlan::Striding{strides};
+    auto maybe_plan = transpose_cache->cache.GetOrCreate(options);
+    if (!maybe_plan.ok()) {
+      return xla::ffi::Error::Internal(maybe_plan.status().ToString());
    }
-
-    auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), data, size_bytes,
+    auto plan = maybe_plan.value();
+    plan->Execute(array.data(), temp);
+    auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), temp, ret->size_bytes(),
                                   gpuMemcpyHostToDevice, stream);
     CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync";
   }
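
Note: the rewritten return path copies host data straight to the device when the numpy array's byte strides already match the dense layout XLA expects, and otherwise stages through a cached TransposePlan. A standalone sketch of the stride test behind that fast path (DenseByteStrides and CanMemcpyDirectly are illustrative helpers, not XLA APIs):

#include <cstdint>
#include <vector>

// Byte strides of a dense, C-ordered (row-major) array with the given
// dimensions and element size; the minor-most dimension varies fastest.
std::vector<int64_t> DenseByteStrides(const std::vector<int64_t>& dims,
                                      int64_t elem_size) {
  std::vector<int64_t> strides(dims.size());
  int64_t stride = elem_size;
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    strides[i] = stride;
    stride *= dims[i];
  }
  return strides;
}

// The host array can be gpuMemcpyAsync'd as-is only when its actual strides
// equal the dense strides; otherwise the bytes must be re-laid-out first.
bool CanMemcpyDirectly(const std::vector<int64_t>& dims,
                       const std::vector<int64_t>& actual_strides,
                       int64_t elem_size) {
  return actual_strides == DenseByteStrides(dims, elem_size);
}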