Skip to content

Commit 9f79799

Browse files
gneculaGoogle-ML-Automation
authored andcommitted
Remove old backward compatibility mode for old PRGN custom call on GPU
The backend support for the new custom call was added on June 28th, 2024 (jax-ml#20997). PiperOrigin-RevId: 723077990
1 parent 6738986 commit 9f79799

File tree

9 files changed

+3
-139
lines changed

9 files changed

+3
-139
lines changed

jax/_src/export/_export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1042,7 +1042,7 @@ def _check_lowering(lowering) -> None:
10421042
*_CPU_FFI_KERNELS,
10431043
*_GPU_FFI_KERNELS,
10441044
"Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
1045-
"cu_threefry2x32", "cu_threefry2x32_ffi",
1045+
"cu_threefry2x32_ffi",
10461046
# Triton IR does not guarantee stability.
10471047
# "__gpu$xla.gpu.triton",
10481048
# cholesky on CPU

jax/_src/internal_test_util/export_back_compat_test_data/cuda_threefry2x32.py

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -15,63 +15,6 @@
1515
import datetime
1616
from numpy import array, float32, uint32
1717

18-
# Pasted from the test output (see back_compat_test.py module docstring)
19-
# TODO(b/338022728): remove after 6 months
20-
data_2023_03_15 = dict(
21-
testdata_version=1,
22-
platform='cuda',
23-
custom_call_targets=['cu_threefry2x32'],
24-
serialized_date=datetime.date(2023, 3, 15),
25-
inputs=(array([42, 43], dtype=uint32),),
26-
expected_outputs=(array([[0.42591238, 0.0769949 , 0.44370103, 0.72904015],
27-
[0.17879379, 0.81439507, 0.00191903, 0.68608475]], dtype=float32),),
28-
mlir_module_text=r"""
29-
module @jit_func {
30-
func.func public @main(%arg0: tensor<2xui32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<2x4xf32> {jax.result_info = ""}) {
31-
%0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
32-
%1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<f32>) -> tensor<1x1xf32>
33-
%2 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
34-
%3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<f32>) -> tensor<1x1xf32>
35-
%4 = stablehlo.iota dim = 0 : tensor<8xui32>
36-
%5 = "stablehlo.slice"(%arg0) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xui32>) -> tensor<1xui32>
37-
%6 = stablehlo.reshape %5 : (tensor<1xui32>) -> tensor<ui32>
38-
%7 = "stablehlo.slice"(%arg0) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xui32>) -> tensor<1xui32>
39-
%8 = stablehlo.reshape %7 : (tensor<1xui32>) -> tensor<ui32>
40-
%9 = "stablehlo.slice"(%4) {limit_indices = dense<4> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<8xui32>) -> tensor<4xui32>
41-
%10 = "stablehlo.slice"(%4) {limit_indices = dense<8> : tensor<1xi64>, start_indices = dense<4> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<8xui32>) -> tensor<4xui32>
42-
%11 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor<ui32>) -> tensor<4xui32>
43-
%12 = stablehlo.broadcast_in_dim %8, dims = [] : (tensor<ui32>) -> tensor<4xui32>
44-
%13 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<4xui32>) -> tensor<4xui32>
45-
%14 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<4xui32>) -> tensor<4xui32>
46-
%15 = stablehlo.custom_call @cu_threefry2x32(%11, %12, %13, %14) {api_version = 2 : i32, backend_config = "\04\00\00\00\00\00\00\00", operand_layouts = [dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<4xui32>, tensor<4xui32>, tensor<4xui32>, tensor<4xui32>) -> tuple<tensor<4xui32>, tensor<4xui32>>
47-
%16 = stablehlo.get_tuple_element %15[0] : (tuple<tensor<4xui32>, tensor<4xui32>>) -> tensor<4xui32>
48-
%17 = stablehlo.get_tuple_element %15[1] : (tuple<tensor<4xui32>, tensor<4xui32>>) -> tensor<4xui32>
49-
%18 = stablehlo.concatenate %16, %17, dim = 0 : (tensor<4xui32>, tensor<4xui32>) -> tensor<8xui32>
50-
%19 = stablehlo.reshape %18 : (tensor<8xui32>) -> tensor<2x4xui32>
51-
%20 = stablehlo.constant dense<9> : tensor<ui32>
52-
%21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor<ui32>) -> tensor<2x4xui32>
53-
%22 = stablehlo.shift_right_logical %19, %21 : tensor<2x4xui32>
54-
%23 = stablehlo.constant dense<1065353216> : tensor<ui32>
55-
%24 = stablehlo.broadcast_in_dim %23, dims = [] : (tensor<ui32>) -> tensor<2x4xui32>
56-
%25 = stablehlo.or %22, %24 : tensor<2x4xui32>
57-
%26 = stablehlo.bitcast_convert %25 : (tensor<2x4xui32>) -> tensor<2x4xf32>
58-
%27 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
59-
%28 = stablehlo.broadcast_in_dim %27, dims = [] : (tensor<f32>) -> tensor<2x4xf32>
60-
%29 = stablehlo.subtract %26, %28 : tensor<2x4xf32>
61-
%30 = stablehlo.subtract %3, %1 : tensor<1x1xf32>
62-
%31 = stablehlo.broadcast_in_dim %30, dims = [0, 1] : (tensor<1x1xf32>) -> tensor<2x4xf32>
63-
%32 = stablehlo.multiply %29, %31 : tensor<2x4xf32>
64-
%33 = stablehlo.broadcast_in_dim %1, dims = [0, 1] : (tensor<1x1xf32>) -> tensor<2x4xf32>
65-
%34 = stablehlo.add %32, %33 : tensor<2x4xf32>
66-
%35 = stablehlo.broadcast_in_dim %1, dims = [0, 1] : (tensor<1x1xf32>) -> tensor<2x4xf32>
67-
%36 = stablehlo.maximum %35, %34 : tensor<2x4xf32>
68-
return %36 : tensor<2x4xf32>
69-
}
70-
}
71-
""",
72-
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x013\x05\x01\x05\x01\x03\x05\x03#\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%'\x032\x02\xe1)\x01\x9b\x17\x07\x13\x0f\x0b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0f\x13\x0f\x13\x0b\x0f\x0f\x0f\x0f\x0f\x13\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0b\x13\x0b\x0f\x0b#\x0f\x0b\x0b#\x0f\x0b#\x0f\x0b#\x0f\x0b\x0bK\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x03G///\x0f/\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x1f\x0f\x1f//\x0b\x0b\x0b\x0b\x1b\x13\x0f\x0f\x1f\x1fO\x03)\x17\x13\x07\x0f\x0f\x13\x17\x07\x07\x17\x13\x13\x13\x07\x17\x13\x13\x13\x07\x13\x02\xb6\x07\x17?\xb2\x03\x01\x1f\x03\x03\x11\xc3\x1dc\x01\x05)\x05+\x05-\x05/\x051\x1d\x93\x01\x03\x03\x11\xdf\x053\x1d=\x01\x03\x03\t\xc5\x1dO\x01\x03\x03\x11\x9f\x055\x1d\x89\x01\x1d\x8d\x01\x1d\x95\x01\x1d\x97\x01\x1d\x99\x01\x03\x03\x17/\x057\x03\x0b3\xa75\xb37\xb5\x17\xbd9\xbf\x059\x05;\x05=\x05?\x03\x03\t\xc1\x05A\x05C\x03\x03C\xa1\x05E\x1dG\x01\x05G\x03\x07\x0b\x9b\r\x9f\x0f\x9b\x1dM\x01\x05I\x05K\x03\x07\x0b\xc7\r\x9b\x0f\x9b\x1dU\x01\x05M\x03\x07\x0b\xa3\r\x9f\x0f\x9b\x1d[\x01\x05O\x03\x07\x0b\xc9\r\xa3\x0f\x9b\x1da\x01\x05Q\x05S\x03\x11g\xcbi\xcdk\xcfm\xa5o\xd1q\xd3s\xa5u\xd5\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x05c\x03\x03!\xd7\x03\x03!\xd9\x03\x03}\xa1\x05e\x1d\x81\x01\x05g\x1d\x85\x01\x05i\x03\x03\t\xdb\x05k\x03\x03\t\xdd\x05m\x1d\x91\x01\x05o\x05q\x05s\x05u\x05w\x1f\x0b\x11\x01\x00\x00\x00\x00\x00\x00\x00\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x0b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x13\x0f\x01\x1f\x0b\x11\x04\x00\x00\x00\x00\x00\x00\x00\x03\x01\x03\x03\xa9\r\x05\xab\xad\xaf\xb1\x1dy\x1d{\x1d}\x1d\x7f#\x1d\x03\x03\xb7\r\x03\xb9\xbb\x1d\x81\x1d\x83\x1d\x85\x1d\x87\x1f\t\t\x00\x00\x00\x00\x1f\x1f\x01\x1f\t\t\x00\x00\x80?\x1f\x0b\x11\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x0b\x11\x08\x00\x00\x00\x00\x00\x00\x00\x0b\x05\x1d\x89\x1d\x8b\x05\x01\x03\t\x9d\x9d\x9d\x9d\x03\x05\x9d\x9d\x13\x1b\x01\x13\x1b\x05\x1f\x07\t\t\x00\x00\x00\x1f\x07\t\x00\x00\x80?\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00)\x05\t\x11\x11)\x03\x11\x05%)\x01\x05)\x01\x11)\x03\x05\x0f)\x05\t\x11\x05\x1d\t)\x05\x05\x05\x11)\x03\t\x05)\x03!\x05)\x03\x05\x05\x1b\x11\x03\x15\x03\x01)\x03\x01\x0f/\x05\x03\x03)\x03\x05%\x13)\x03\t\x0f\x04\xd6\x04\x05\x01\x11\x03-\x07\x03\x01\x05\x0f\x11\x031\x05\x03M\x9b\x03\x15\x03\x05\x03\x03;\x03\t\x03\x07\x19\x05\x03\x13\x03\x03\x05\x03\x03\x1b\x03\t\x03\x07\x19\x05\x03\x13\x03\x07\x11\x03EA\x03\x17\x07\x07KI\x03\x19\x03\x01\t\x06\x1d\x03\x07\x03\r\x07\x07SQ\x03\x19\x03\x01\t\x06\x1d\x03\x07\x03\x11\x07\x07YW\x03\x03\x03\x0b\x07\x07_]\x03\x03\x03\x0b\x03\x07\x07\x05\x03\x03\x03\x0f\x03\x07\x07\x05\x03\x03\x03\x13\x03\x07\x07\x1f\x03\x03\x03\x15\x03\x07\x07\x1f\x03\x03\x03\x17\x13\x07\x07e\x03!\t\x19\x1b\x1d\x1f\x0b\x07\x07w\x03\x03\x03!\x0b\x07\x07y\x03\x03\x03!\x15\x07\x7f{\x03\x17\x05#%\t\x06\x83\x03\r\x03'\x05\x03\x03\x87\x03\x07\x03\x07#\x05\x03\r\x03+\x17\x06#\x03\r\x05)-\x05\x03\x03\x8b\x03\x07\x03\x07%\x05\x03\r\x031\x19\x06%\x03\r\x05/3\x1b\x06\x8f\x03\x01\x035\x05\x03\x03\x1b\x03\t\x03\x07\x13\x05\x03\x01\x039\r\x06\x13\x03\x01\x057;\r\x06\x13\x03\x13\x05\t\x05\x03\x07'\x15\x03\x01\x03?\x1d\x06'\x03\x01\x05=A\x03\x07)\x15\x03\x01\x03\x05\x1f\x06)\x03\x01\x05CE\x03\x07+\x15\x03\x01\x03\x05!\x06+\x03\x01\x05IG#\x04\x03\x03K\x06\x03\x01\x05\x01\x00N\x19\x8d!\x13\x0f\x0b\x03!\x1b\x1d\x05\x1b1111y/Q}[\x15\x1f/!!)#\x1f\x19C\x9d\x9d\x9d[\x9d}\x1f\x83\x97\x1f\x15\x1d\x15\x13\r\x13+\x11\x1d\x1d\r\x15\x17\x0f\x19'\r/\x1f\x1f\x11\x11\x19+\x17\x13\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00slice_v1\x00reshape_v1\x00get_tuple_element_v1\x00subtract_v1\x00func_v1\x00iota_v1\x00custom_call_v1\x00concatenate_v1\x00shift_right_logical_v1\x00or_v1\x00bitcast_convert_v1\x00multiply_v1\x00add_v1\x00maximum_v1\x00return_v1\x00value\x00limit_indices\x00start_indices\x00strides\x00broadcast_dimensions\x00sym_name\x00index\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/broadcast_in_dim[shape=(1, 1) broadcast_dimensions=()]\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=uint32 shape=(8,) dimension=0]\x00jit(func)/jit(main)/slice[start_indices=(0,) limit_indices=(1,) strides=(1,)]\x00jit(func)/jit(main)/squeeze[dimensions=(0,)]\x00jit(func)/jit(main)/slice[start_indices=(1,) limit_indices=(2,) strides=(1,)]\x00jit(func)/jit(main)/slice[start_indices=(0,) limit_indices=(4,) strides=None]\x00jit(func)/jit(main)/slice[start_indices=(4,) limit_indices=(8,) strides=None]\x00jit(func)/jit(main)/threefry2x32\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00dimension\x00jit(func)/jit(main)/concatenate[dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(2, 4) dimensions=None]\x00jit(func)/jit(main)/shift_right_logical\x00jit(func)/jit(main)/or\x00jit(func)/jit(main)/bitcast_convert_type[new_dtype=float32]\x00jit(func)/jit(main)/sub\x00jit(func)/jit(main)/mul\x00jit(func)/jit(main)/add\x00jit(func)/jit(main)/max\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00\x00main\x00public\x00\x04\x00\x00\x00\x00\x00\x00\x00\x00cu_threefry2x32\x00",
73-
xla_call_module_version=4,
74-
) # End paste
7518

7619
# Pasted from the test output (see export_back_compat_test_util.py module docstring)
7720
data_2024_07_30 = dict(

jaxlib/gpu/gpu_kernels.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn", RNNForward, "CUDA");
4141
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn_bwd", RNNBackward, "CUDA");
4242
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_cholesky_update", CholeskyUpdate,
4343
"CUDA");
44-
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_threefry2x32", ThreeFry2x32,
45-
"CUDA");
4644
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA");
4745
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_getrf_ffi", "CUDA",
4846
GetrfFfi);

jaxlib/gpu/prng.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,10 @@ namespace {
2323

2424
namespace nb = nanobind;
2525

26-
std::string BuildThreeFry2x32Descriptor(std::int64_t n) {
27-
return PackDescriptorAsString(ThreeFry2x32Descriptor{n});
28-
}
2926
nb::dict Registrations() {
3027
nb::dict dict;
3128
dict[JAX_GPU_PREFIX "_threefry2x32_ffi"] =
3229
EncapsulateFfiHandler(ThreeFry2x32Ffi);
33-
// TODO(b/338022728): remove after 6 months
34-
dict[JAX_GPU_PREFIX "_threefry2x32"] = EncapsulateFunction(ThreeFry2x32);
3530
return dict;
3631
}
3732

jaxlib/gpu/prng_kernels.cc

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -33,29 +33,6 @@ namespace JAX_GPU_NAMESPACE {
3333

3434
namespace ffi = xla::ffi;
3535

36-
namespace {
37-
38-
// TODO(b/338022728): old custom call target, remove after 6 months
39-
absl::Status ThreeFry2x32_(gpuStream_t stream, void** buffers,
40-
const char* opaque, std::size_t opaque_len) {
41-
auto s = UnpackDescriptor<ThreeFry2x32Descriptor>(opaque, opaque_len);
42-
JAX_RETURN_IF_ERROR(s.status());
43-
LaunchThreeFry2x32Kernel(stream, buffers, **s);
44-
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError()));
45-
return absl::OkStatus();
46-
}
47-
48-
} // namespace
49-
50-
// TODO(b/338022728): remove after 6 months
51-
void ThreeFry2x32(gpuStream_t stream, void** buffers, const char* opaque,
52-
size_t opaque_len, XlaCustomCallStatus* status) {
53-
auto s = ThreeFry2x32_(stream, buffers, opaque, opaque_len);
54-
if (!s.ok()) {
55-
std::string_view message = s.message();
56-
XlaCustomCallStatusSetFailure(status, message.data(), message.length());
57-
}
58-
}
5936

6037
namespace {
6138
ffi::Error ThreeFry2x32Impl(gpuStream_t stream,

jaxlib/gpu/prng_kernels.cu.cc

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -121,35 +121,6 @@ void LaunchThreeFry2x32KernelFfi(gpuStream_t stream,
121121
out1, n, nullptr);
122122
}
123123

124-
// TODO(b/338022728): remove after 6 months
125-
void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers,
126-
ThreeFry2x32Descriptor descriptor) {
127-
std::array<const std::uint32_t*, 2> keys;
128-
keys[0] = reinterpret_cast<const std::uint32_t*>(buffers[0]);
129-
keys[1] = reinterpret_cast<const std::uint32_t*>(buffers[1]);
130-
std::array<const std::uint32_t*, 2> data;
131-
data[0] = reinterpret_cast<const std::uint32_t*>(buffers[2]);
132-
data[1] = reinterpret_cast<const std::uint32_t*>(buffers[3]);
133-
std::int64_t n = descriptor.n;
134-
int output_idx = 4;
135-
std::int64_t* n_ptr = nullptr;
136-
if (n < 0) {
137-
// n is an operand in device memory.
138-
n_ptr = reinterpret_cast<std::int64_t*>(buffers[4]);
139-
output_idx = 5;
140-
}
141-
142-
std::array<std::uint32_t*, 2> out;
143-
out[0] = reinterpret_cast<std::uint32_t*>(buffers[output_idx]);
144-
out[1] = reinterpret_cast<std::uint32_t*>(buffers[output_idx + 1]);
145-
const int block_dim = 128;
146-
const std::int64_t grid_dim =
147-
n < 0 ? 32
148-
: std::min<std::int64_t>(1024, (n + block_dim - 1) / block_dim);
149-
ThreeFry2x32Kernel<<<grid_dim, block_dim, /*dynamic_shared_mem_bytes=*/0,
150-
stream>>>(keys[0], keys[1], data[0], data[1], out[0],
151-
out[1], n, n_ptr);
152-
}
153124

154125
} // namespace JAX_GPU_NAMESPACE
155126
} // namespace jax

jaxlib/gpu/prng_kernels.h

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,6 @@ limitations under the License.
2626
namespace jax {
2727
namespace JAX_GPU_NAMESPACE {
2828

29-
// TODO(b/338022728): remove after 6 months
30-
struct ThreeFry2x32Descriptor {
31-
std::int64_t n; // If -1 then the length is passed as a 5th operand
32-
};
33-
34-
// TODO(b/338022728): remove after 6 months
35-
void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers,
36-
ThreeFry2x32Descriptor descriptor);
37-
38-
// TODO(b/338022728): remove after 6 months
39-
void ThreeFry2x32(gpuStream_t stream, void** buffers, const char* opaque,
40-
size_t opaque_len, XlaCustomCallStatus* status);
41-
4229
void LaunchThreeFry2x32KernelFfi(gpuStream_t stream,
4330
std::int64_t n,
4431
std::uint32_t *keys0, std::uint32_t *keys1,

jaxlib/gpu_prng.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,8 @@
3636

3737
if _cuda_prng:
3838
for _name, _value in _cuda_prng.registrations().items():
39-
# TODO(b/338022728): remove after 6 months, always api_version=1
40-
api_version = 1 if "_ffi" in _name else 0
4139
xla_client.register_custom_call_target(_name, _value, platform="CUDA",
42-
api_version=api_version)
40+
api_version=1)
4341

4442
for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
4543
try:

tests/export_back_compat_test.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def test_custom_call_coverage(self):
136136
cpu_eig_lapack_geev.data_2023_06_19,
137137
cpu_eigh_lapack_syev.data_2023_03_17,
138138
cpu_qr_lapack_geqrf.data_2023_03_17,
139-
cuda_threefry2x32.data_2023_03_15, cuda_threefry2x32.data_2024_07_30,
139+
cuda_threefry2x32.data_2024_07_30,
140140
cpu_lu_lapack_getrf.data_2023_06_14,
141141
cuda_lu_pivots_to_permutation.data_2024_08_08,
142142
cuda_lu_cusolver_getrf.data_2024_08_19,
@@ -856,11 +856,6 @@ def test_cuda_threefry2x32(self):
856856
def func(x):
857857
return jax.random.uniform(x, (2, 4), dtype=np.float32)
858858

859-
# TODO(b/338022728): remove after 6 months
860-
data = self.load_testdata(cuda_threefry2x32.data_2023_03_15)
861-
self.run_one_test(func, data,
862-
expect_current_custom_calls=["cu_threefry2x32_ffi"])
863-
864859
data = self.load_testdata(cuda_threefry2x32.data_2024_07_30)
865860
self.run_one_test(func, data)
866861

0 commit comments

Comments
 (0)