
Commit a099b28

danielsuo authored and Google-ML-Automation committed

Reverts 735cec1

PiperOrigin-RevId: 744717457

1 parent: 412f88e

File tree: 6 files changed (+95, -177 lines)
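In short: this change reverts the sub-byte host-callback support introduced in 735cec1. The CPU and GPU callback paths below once again reject 2- and 4-bit integer element types (S2/S4/U2/U4) alongside the 1-bit ones, the pack/unpack conversions around numpy arrays (xla::UnpackIntN / xla::PackIntN) are deleted, and the now-unused @xla//xla:util dependency is dropped from the BUILD files.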

jaxlib/cuda/BUILD
Lines changed: 0 additions & 1 deletion

@@ -637,7 +637,6 @@ cc_library(
         "@xla//third_party/python_runtime:headers",  # buildcleaner: keep
         "@xla//xla:comparison_util",
         "@xla//xla:shape_util",
-        "@xla//xla:util",
         "@xla//xla:xla_data_proto_cc",
         "@xla//xla/ffi:ffi_api",
         "@xla//xla/ffi/api:ffi",

jaxlib/gpu/py_client_gpu.cc
Lines changed: 31 additions & 57 deletions

@@ -44,7 +44,6 @@ limitations under the License.
 #include "xla/python/types.h"
 #include "xla/shape_util.h"
 #include "xla/xla_data.pb.h"
-#include "xla/util.h"
 
 namespace nb = nanobind;
 
@@ -82,7 +81,8 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream,
     auto arg = args.get<xla::ffi::AnyBuffer>(i);
     auto ptype = static_cast<xla::PrimitiveType>(arg->element_type());
     // TODO(b/395428868): Remove this check once we support subbyte types.
-    if (ptype == xla::S1 || ptype == xla::U1) {
+    if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 ||
+        ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) {
       return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented,
                              absl::StrFormat("Unsupported primitive type: %s",
                                              PrimitiveType_Name(ptype)));
@@ -112,30 +112,16 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream,
       PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr());
       continue;
     }
+    nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept {
+      delete[] static_cast<char*>(ptr);
+    });
     auto maybe_dtype = PrimitiveTypeToNbDtype(ptype);
     if (!maybe_dtype.ok()) {
       return xla::ffi::Error::Internal(maybe_dtype.status().ToString());
     }
     auto dtype = maybe_dtype.value();
     auto dims = absl::Span<const int64_t>(arg->dimensions().begin(),
                                           arg->dimensions().size());
-    // TODO(b/402422886): Remove this once we form Jax arrays directly instead
-    // of packing/unpacking to/from numpy arrays.
-    // We pass in data using default numpy layout i.e., std::nullopt.
-    size_t bits_per_element = xla::primitive_util::BitWidth(ptype);
-    if (bits_per_element == 2 || bits_per_element == 4) {
-      // NOTE(dsuo): FFI argument and return buffers are sized assuming
-      // minimum 1-byte element sizes, even if the data itself is packed.
-      size_t packed_size = arg->size_bytes() * bits_per_element / 8;
-      auto buffer = xla::UnpackIntN(
-          bits_per_element, static_cast<const char*>(host_input_buffers[i]),
-          packed_size);
-      delete[] static_cast<char*>(host_input_buffers[i]);
-      host_input_buffers[i] = buffer.release();
-    }
-    nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept {
-      delete[] static_cast<char*>(ptr);
-    });
     auto array = xla::nb_numpy_ndarray(dtype, dims, std::nullopt,
                                        host_input_buffers[i], base);
     array.attr("flags").attr("writeable") = nb::bool_(false);
@@ -160,7 +146,8 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream,
     auto ret = rets.get<xla::ffi::AnyBuffer>(i).value();
     auto ptype = static_cast<xla::PrimitiveType>(ret->element_type());
     // TODO(b/395428868): Remove this check once we support subbyte types.
-    if (ptype == xla::S1 || ptype == xla::U1) {
+    if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 ||
+        ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) {
       return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented,
                              absl::StrFormat("Unsupported primitive type: %s",
                                              PrimitiveType_Name(ptype)));
@@ -181,45 +168,32 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream,
     }
     auto expected_shape = maybe_expected_shape.value();
     auto expected_strides = xla::ByteStridesForShape(expected_shape);
-
-    const void* data = array.data();
-    size_t size_bytes = array.size() * array.itemsize();
-    if (strides != expected_strides) {
-      xla::TransposePlan::Options options;
-      options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype);
-      options.dims = absl::Span<int64_t const>(
-          reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
-      absl::InlinedVector<int64_t, 4> reversed_layout;
-      reversed_layout.resize(expected_shape.dimensions().size());
-      absl::c_reverse_copy(expected_shape.layout().minor_to_major(),
-                           reversed_layout.begin());
-      options.permutation = reversed_layout;
-      options.input_layout = xla::TransposePlan::Striding{strides};
-      auto maybe_plan = transpose_cache->cache.GetOrCreate(options);
-      if (!maybe_plan.ok()) {
-        return xla::ffi::Error::Internal(maybe_plan.status().ToString());
-      }
-      auto plan = maybe_plan.value();
-      void* temp = new char[size_bytes];
-      temp_buffers.push_back(temp);
-      plan->Execute(data, temp);
-      data = temp;
+    if (strides == expected_strides) {
+      auto gpu_res =
+          gpuMemcpyAsync(ret->untyped_data(), array.data(), ret->size_bytes(),
+                         gpuMemcpyHostToDevice, stream);
+      CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync";
+      continue;
     }
-
-    // TODO(b/402422886): Remove this once we form Jax arrays directly instead
-    // of packing/unpacking to/from numpy arrays.
-    std::unique_ptr<char[]> buffer;
-    size_t bits_per_element = xla::primitive_util::BitWidth(ptype);
-    if (bits_per_element == 2 || bits_per_element == 4) {
-      // NOTE(dsuo): FFI arguments and return buffers are sized assuming
-      // minimum 1-byte element sizes, even if the data itself is packed.
-      buffer = xla::PackIntN(bits_per_element, static_cast<const char*>(data),
-                             size_bytes);
-      data = buffer.get();
-      size_bytes = (size_bytes * bits_per_element) / 8;
+    void* temp = new char[ret->size_bytes()];
+    temp_buffers.push_back(temp);
+    xla::TransposePlan::Options options;
+    options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype);
+    options.dims = absl::Span<int64_t const>(
+        reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
+    absl::InlinedVector<int64_t, 4> reversed_layout;
+    reversed_layout.resize(expected_shape.dimensions().size());
+    absl::c_reverse_copy(expected_shape.layout().minor_to_major(),
+                         reversed_layout.begin());
+    options.permutation = reversed_layout;
+    options.input_layout = xla::TransposePlan::Striding{strides};
+    auto maybe_plan = transpose_cache->cache.GetOrCreate(options);
+    if (!maybe_plan.ok()) {
+      return xla::ffi::Error::Internal(maybe_plan.status().ToString());
    }
-
-    auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), data, size_bytes,
+    auto plan = maybe_plan.value();
+    plan->Execute(array.data(), temp);
+    auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), temp, ret->size_bytes(),
                                   gpuMemcpyHostToDevice, stream);
     CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync";
   }
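For orientation, the restored return path in XlaFfiPythonGpuCallback reduces to: if the numpy result already has the expected strides, copy it straight to the device return buffer; otherwise transpose it into a temporary host buffer first, then copy. Below is a minimal standalone sketch of that control flow, assuming plain CUDA (jaxlib's gpuMemcpyAsync / gpuMemcpyHostToDevice / gpuSuccess aliases resolve to the cudaMemcpy* API on CUDA builds); the function name is illustrative, and the transpose step is stubbed with a byte copy.

// Sketch of the reverted return path; not the jaxlib implementation.
#include <cstddef>
#include <cstring>
#include <vector>

#include <cuda_runtime.h>

// Copies one host result into its device return buffer. Transpose scratch is
// collected in temp_buffers because it must outlive the async copy; the real
// code frees it only after synchronizing the stream.
bool CopyResultToDevice(const void* host_data, void* device_ret,
                        size_t size_bytes, bool strides_match,
                        std::vector<char*>& temp_buffers,
                        cudaStream_t stream) {
  const void* src = host_data;
  if (!strides_match) {
    // The real code builds an xla::TransposePlan for the expected layout and
    // executes it into temp; a plain copy stands in for that step here.
    char* temp = new char[size_bytes];
    temp_buffers.push_back(temp);
    std::memcpy(temp, host_data, size_bytes);
    src = temp;
  }
  // Asynchronous host-to-device copy on the callback's stream.
  return cudaMemcpyAsync(device_ret, src, size_bytes, cudaMemcpyHostToDevice,
                         stream) == cudaSuccess;
}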

jaxlib/rocm/BUILD
Lines changed: 0 additions & 1 deletion

@@ -539,7 +539,6 @@ cc_library(
         "@xla//third_party/python_runtime:headers",  # buildcleaner: keep
         "@xla//xla:comparison_util",
         "@xla//xla:shape_util",
-        "@xla//xla:util",
         "@xla//xla:xla_data_proto_cc",
         "@xla//xla/ffi:ffi_api",
         "@xla//xla/ffi/api:ffi",

jaxlib/xla/BUILD
Lines changed: 0 additions & 1 deletion

@@ -640,7 +640,6 @@ cc_library(
         "@nanobind",
         "@xla//third_party/python_runtime:headers",  # buildcleaner: keep
         "@xla//xla:shape_util",
-        "@xla//xla:util",
         "@xla//xla:xla_data_proto_cc",
         "@xla//xla/ffi:ffi_api",
         "@xla//xla/ffi/api:ffi",

jaxlib/xla/py_client_cpu.cc
Lines changed: 24 additions & 63 deletions

@@ -41,7 +41,6 @@ limitations under the License.
 #include "xla/python/nb_numpy.h"
 #include "xla/python/types.h"
 #include "xla/shape_util.h"
-#include "xla/util.h"
 #include "xla/xla_data.pb.h"
 
 namespace nb = nanobind;
@@ -80,7 +79,8 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks,
     auto arg = args.get<ffi::AnyBuffer>(i);
     auto ptype = static_cast<PrimitiveType>(arg->element_type());
     // TODO(b/395428868): Remove this check once we support subbyte types.
-    if (ptype == S1 || ptype == U1) {
+    if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 ||
+        ptype == U2 || ptype == U4) {
       return ffi::Error(ffi::ErrorCode::kUnimplemented,
                         absl::StrFormat("Unsupported primitive type: %s",
                                         PrimitiveType_Name(ptype)));
@@ -96,20 +96,9 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks,
     auto dtype = maybe_dtype.value();
     auto dims = absl::Span<const int64_t>(arg->dimensions().begin(),
                                           arg->dimensions().size());
-    // TODO(b/402422886): Remove this once we form Jax arrays directly instead
-    std::unique_ptr<char[]> buffer;
-    const void* data = arg->untyped_data();
-    size_t bits_per_element = xla::primitive_util::BitWidth(ptype);
-    if (bits_per_element == 2 || bits_per_element == 4) {
-      // NOTE(dsuo): FFI argument and return buffers are sized assuming
-      size_t packed_size = arg->size_bytes() * bits_per_element / 8;
-      buffer = xla::UnpackIntN(bits_per_element, static_cast<const char*>(data),
-                               packed_size);
-      data = buffer.get();
-    }
     // We pass in data using default numpy layout i.e., std::nullopt.
     auto array =
-        nb_numpy_ndarray(dtype, dims, std::nullopt, data);
+        nb_numpy_ndarray(dtype, dims, std::nullopt, arg.value().untyped_data());
     array.attr("flags").attr("writeable") = nb::bool_(false);
     PyTuple_SET_ITEM(nb_args.ptr(), i, array.release().ptr());
   }
@@ -130,8 +119,9 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks,
   for (size_t i = 0; i < rets.size(); ++i) {
     auto ret = rets.get<ffi::AnyBuffer>(i).value();
     auto ptype = static_cast<PrimitiveType>(ret->element_type());
-    // TODO(b/402422886): Remove this once we form Jax arrays directly instead
-    if (ptype == S1 || ptype == U1) {
+    // TODO(b/395428868): Remove this check once we support subbyte types.
+    if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 ||
+        ptype == U2 || ptype == U4) {
       return ffi::Error(ffi::ErrorCode::kUnimplemented,
                         absl::StrFormat("Unsupported primitive type: %s",
                                         PrimitiveType_Name(ptype)));
@@ -151,55 +141,26 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks,
     }
     auto expected_shape = maybe_expected_shape.value();
     auto expected_strides = ByteStridesForShape(expected_shape);
-
-    const void* data = array.data();
-    std::unique_ptr<char[]> buffer;
-    size_t bits_per_element = xla::primitive_util::BitWidth(ptype);
-    size_t size_bytes = array.size() * array.itemsize();
-    if (strides != expected_strides) {
-      xla::TransposePlan::Options options;
-      options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype);
-      options.dims = absl::Span<const int64_t>(
-          reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
-      absl::InlinedVector<int64_t, 4> reversed_layout;
-      reversed_layout.resize(expected_shape.dimensions().size());
-      absl::c_reverse_copy(expected_shape.layout().minor_to_major(),
-                           reversed_layout.begin());
-      options.permutation = reversed_layout;
-      options.input_layout = xla::TransposePlan::Striding{strides};
-      auto maybe_plan = transpose_cache->cache.GetOrCreate(options);
-      if (!maybe_plan.ok()) {
-        return ffi::Error::Internal(maybe_plan.status().ToString());
-      }
-      auto plan = maybe_plan.value();
-      if (bits_per_element == 2 || bits_per_element == 4) {
-        // NOTE(dsuo): If the data needs to be unpacked, don't use return buffer
-        // supplied by FFI directly.
-        buffer = std::make_unique<char[]>(size_bytes);
-        plan->Execute(data, buffer.get());
-        data = buffer.get();
-      } else {
-        plan->Execute(data, ret->untyped_data());
-        data = ret->untyped_data();
-      }
-    }
-
-    // TODO(b/402422886): Remove this once we form Jax arrays directly instead
-    // of packing/unpacking to/from numpy arrays.
-    if (bits_per_element == 2 || bits_per_element == 4) {
-      // NOTE(dsuo): FFI arguments and return buffers are sized assuming
-      // minimum 1-byte element sizes, even if the data itself is packed.
-      buffer = xla::PackIntN(bits_per_element, static_cast<const char*>(data),
-                             size_bytes);
-      data = buffer.get();
-      size_bytes = (size_bytes * bits_per_element) / 8;
+    if (strides == expected_strides) {
+      std::memcpy(ret->untyped_data(), array.data(), ret->size_bytes());
+      continue;
     }
-
-    // Copy data to output buffer if haven't already or modified the data to
-    // write back.
-    if (data != ret->untyped_data()) {
-      std::memcpy(ret->untyped_data(), data, size_bytes);
+    xla::TransposePlan::Options options;
+    options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype);
+    options.dims = absl::Span<const int64_t>(
+        reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
+    absl::InlinedVector<int64_t, 4> reversed_layout;
+    reversed_layout.resize(expected_shape.dimensions_size());
+    absl::c_reverse_copy(expected_shape.layout().minor_to_major(),
+                         reversed_layout.begin());
+    options.permutation = reversed_layout;
+    options.input_layout = xla::TransposePlan::Striding{strides};
+    auto maybe_plan = transpose_cache->cache.GetOrCreate(options);
+    if (!maybe_plan.ok()) {
+      return ffi::Error::Internal(maybe_plan.status().ToString());
     }
+    auto plan = maybe_plan.value();
+    plan->Execute(array.data(), ret->untyped_data());
   }
 
   return ffi::Error::Success();
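For context on what gets removed: the reverted code widened packed 2-/4-bit FFI buffers to one byte per element before wrapping them in numpy arrays (xla::UnpackIntN), and re-packed results on the way out (xla::PackIntN). Below is a minimal standalone sketch of the 4-bit case; the helper names and the nibble ordering are illustrative assumptions here, not the XLA API, and sign extension for signed types is ignored.

// Sketch only: int4 <-> one-byte-per-element conversions.
#include <cstddef>
#include <cstdint>
#include <memory>

// Widen num_elements packed int4 values (two per byte) to one byte each.
std::unique_ptr<char[]> UnpackInt4(const char* packed, size_t num_elements) {
  auto out = std::make_unique<char[]>(num_elements);
  for (size_t i = 0; i < num_elements; ++i) {
    uint8_t byte = static_cast<uint8_t>(packed[i / 2]);
    out[i] = static_cast<char>((i % 2 == 0) ? (byte >> 4) : (byte & 0x0F));
  }
  return out;
}

// Inverse: pack one-byte-per-element values back into two per byte.
// make_unique value-initializes the array, so the |= below starts from zero.
std::unique_ptr<char[]> PackInt4(const char* unpacked, size_t num_elements) {
  auto out = std::make_unique<char[]>((num_elements + 1) / 2);
  for (size_t i = 0; i < num_elements; ++i) {
    uint8_t nibble = static_cast<uint8_t>(unpacked[i]) & 0x0F;
    out[i / 2] |= static_cast<char>((i % 2 == 0) ? (nibble << 4) : nibble);
  }
  return out;
}

This mirrors the size arithmetic visible in the deleted lines (packed_size = size_bytes * bits_per_element / 8): FFI buffers are sized at one byte per element even when the payload is packed.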
