Skip to content

Commit 5e1f60d

Browse files
authored
Add widen_wrap opt (#1511)
* Add widen_wrap opt * Update Compiler.jl * Update Compiler.jl * fix rebase * fix * fix * more rebase * fix * bump commit * Update WORKSPACE * Update WORKSPACE * Update Project.toml
1 parent 48ec28b commit 5e1f60d

File tree

4 files changed

+79
-65
lines changed

4 files changed

+79
-65
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.151"
4+
version = "0.2.152"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -97,7 +97,7 @@ PythonCall = "0.9.25"
9797
Random = "1.10"
9898
Random123 = "1.7"
9999
ReactantCore = "0.1.15"
100-
Reactant_jll = "0.0.225"
100+
Reactant_jll = "0.0.226"
101101
ScopedValues = "1.3.0"
102102
Scratch = "1.2"
103103
Sockets = "1.10"

deps/ReactantExtra/API.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1435,8 +1435,7 @@ extern "C" HeldIfrtArray *ifrt_client_make_array_from_host_buffer(
14351435
std::nullopt, // byte_strides
14361436
sharding->obj(),
14371437
static_cast<ifrt::Client::HostBufferSemantics>(c_semantics),
1438-
[] {}, // on_done_with_host_buffer,
1439-
client->CreateUserContext())));
1438+
[] {})));
14401439
}
14411440

14421441
extern "C" HeldIfrtArray *ifrt_client_make_single_shard_array_from_host_buffer(
@@ -1846,7 +1845,7 @@ ifrt_CreateDeviceListFromDevices(ifrt::Client *client,
18461845
ifrt::Device **device_list,
18471846
int32_t num_devices) {
18481847
absl::Span<ifrt::Device *const> devices(device_list, num_devices);
1849-
return client->MakeDeviceList(devices);
1848+
return MyValueOrThrow(client->MakeDeviceList(devices));
18501849
}
18511850

18521851
extern "C" ifrt::Memory *ifrt_DeviceGetDefaultMemory(ifrt::Device *device) {
@@ -2660,8 +2659,7 @@ extern "C" HeldIfrtArray *ifrt_make_array_from_host_buffer_shards(
26602659
sharding);
26612660
auto arrays = MyValueOrThrow(client->MakeArraysFromHostBufferShards(
26622661
absl::MakeSpan(&spec, 1),
2663-
static_cast<ifrt::Client::HostBufferSemantics>(c_host_buffer_semantics),
2664-
client->CreateUserContext()));
2662+
static_cast<ifrt::Client::HostBufferSemantics>(c_host_buffer_semantics)));
26652663
return reactant::capture(arrays[0]);
26662664
}
26672665

deps/ReactantExtra/WORKSPACE

Lines changed: 67 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -4,52 +4,24 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"
44

55
NSYNC_SHA256 = ""
66

7+
ENZYMEXLA_COMMIT = "30194bfd56e844ebf4e5fa8a946efa26d344d1cd"
8+
9+
ENZYMEXLA_SHA256 = ""
10+
711
http_archive(
812
name = "nsync",
913
sha256 = NSYNC_SHA256,
1014
strip_prefix = "nsync-" + NSYNC_COMMIT,
1115
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
1216
)
1317

14-
ENZYMEXLA_COMMIT = "0dc3ef87806ab3c9a695fe5c5689f8e2baf0d6cb"
15-
16-
ENZYMEXLA_SHA256 = ""
17-
1818
http_archive(
1919
name = "enzyme_ad",
2020
sha256 = ENZYMEXLA_SHA256,
2121
strip_prefix = "Enzyme-JAX-" + ENZYMEXLA_COMMIT,
2222
urls = ["https://github.com/EnzymeAD/Enzyme-JAX/archive/{commit}.tar.gz".format(commit = ENZYMEXLA_COMMIT)],
2323
)
2424

25-
# Hedron's Compile Commands Extractor for Bazel
26-
# https://github.com/hedronvision/bazel-compile-commands-extractor
27-
http_archive(
28-
name = "hedron_compile_commands",
29-
strip_prefix = "bazel-compile-commands-extractor-4f28899228fb3ad0126897876f147ca15026151e",
30-
31-
# Replace the commit hash (0e990032f3c5a866e72615cf67e5ce22186dcb97) in both places (below) with the latest (https://github.com/hedronvision/bazel-compile-commands-extractor/commits/main), rather than using the stale one here.
32-
# Even better, set up Renovate and let it do the work for you (see "Suggestion: Updates" in the README).
33-
url = "https://github.com/hedronvision/bazel-compile-commands-extractor/archive/4f28899228fb3ad0126897876f147ca15026151e.tar.gz",
34-
# When you first run this tool, it'll recommend a sha256 hash to put here with a message like: "DEBUG: Rule 'hedron_compile_commands' indicated that a canonical reproducible form can be obtained by modifying arguments sha256 = ..."
35-
)
36-
37-
load("@hedron_compile_commands//:workspace_setup.bzl", "hedron_compile_commands_setup")
38-
39-
hedron_compile_commands_setup()
40-
41-
load("@hedron_compile_commands//:workspace_setup_transitive.bzl", "hedron_compile_commands_setup_transitive")
42-
43-
hedron_compile_commands_setup_transitive()
44-
45-
load("@hedron_compile_commands//:workspace_setup_transitive_transitive.bzl", "hedron_compile_commands_setup_transitive_transitive")
46-
47-
hedron_compile_commands_setup_transitive_transitive()
48-
49-
load("@hedron_compile_commands//:workspace_setup_transitive_transitive_transitive.bzl", "hedron_compile_commands_setup_transitive_transitive_transitive")
50-
51-
hedron_compile_commands_setup_transitive_transitive_transitive()
52-
5325
load("@enzyme_ad//:workspace.bzl", "ENZYME_COMMIT", "ENZYME_SHA256", "JAX_COMMIT", "JAX_SHA256", "XLA_PATCHES")
5426

5527
CUPTI_OLD = [
@@ -177,25 +149,6 @@ http_archive(
177149
urls = ["https://github.com/giordano/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)],
178150
)
179151

180-
load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules")
181-
182-
python_init_rules()
183-
184-
load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories")
185-
186-
python_init_repositories(
187-
requirements = {
188-
"3.9": "//build:requirements_lock_3_9.txt",
189-
"3.10": "//build:requirements_lock_3_10.txt",
190-
"3.11": "//build:requirements_lock_3_11.txt",
191-
"3.12": "//build:requirements_lock_3_12.txt",
192-
"3.13": "//build:requirements_lock_3_13.txt",
193-
},
194-
)
195-
196-
load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains")
197-
198-
python_init_toolchains()
199152
#
200153
# load("@xla//third_party/py:python_init_pip.bzl", "python_init_pip")
201154
# python_init_pip()
@@ -233,6 +186,34 @@ load("@jax//third_party/xla:workspace.bzl", jax_xla_workspace = "repo")
233186

234187
jax_xla_workspace()
235188

189+
load("@xla//:workspace4.bzl", "xla_workspace4")
190+
191+
xla_workspace4()
192+
193+
load("@xla//:workspace3.bzl", "xla_workspace3")
194+
195+
xla_workspace3()
196+
197+
load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules")
198+
199+
python_init_rules()
200+
201+
load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories")
202+
203+
python_init_repositories(
204+
requirements = {
205+
"3.9": "//build:requirements_lock_3_9.txt",
206+
"3.10": "//build:requirements_lock_3_10.txt",
207+
"3.11": "//build:requirements_lock_3_11.txt",
208+
"3.12": "//build:requirements_lock_3_12.txt",
209+
"3.13": "//build:requirements_lock_3_13.txt",
210+
},
211+
)
212+
213+
load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains")
214+
215+
python_init_toolchains()
216+
236217
load("@xla//third_party/llvm:workspace.bzl", llvm = "repo")
237218

238219
llvm("llvm-raw")
@@ -244,13 +225,6 @@ llvm_configure(
244225
targets = LLVM_TARGETS,
245226
)
246227

247-
load("@xla//:workspace4.bzl", "xla_workspace4")
248-
249-
xla_workspace4()
250-
251-
load("@xla//:workspace3.bzl", "xla_workspace3")
252-
253-
xla_workspace3()
254228

255229
load("@xla//:workspace2.bzl", "xla_workspace2")
256230

@@ -285,6 +259,12 @@ load("@jax//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
285259

286260
flatbuffers()
287261

262+
load("@jax//:test_shard_count.bzl", "test_shard_count_repository")
263+
264+
test_shard_count_repository(
265+
name = "test_shard_count",
266+
)
267+
288268
load(
289269
"@rules_ml_toolchain//cc_toolchain/deps:cc_toolchain_deps.bzl",
290270
"cc_toolchain_deps",
@@ -367,3 +347,32 @@ load(
367347
nvshmem_redist_init_repository(
368348
nvshmem_redistributions = NVSHMEM_REDISTRIBUTIONS,
369349
)
350+
351+
# Hedron's Compile Commands Extractor for Bazel
352+
# https://github.com/hedronvision/bazel-compile-commands-extractor
353+
http_archive(
354+
name = "hedron_compile_commands",
355+
strip_prefix = "bazel-compile-commands-extractor-4f28899228fb3ad0126897876f147ca15026151e",
356+
357+
# Replace the commit hash (0e990032f3c5a866e72615cf67e5ce22186dcb97) in both places (below) with the latest (https://github.com/hedronvision/bazel-compile-commands-extractor/commits/main), rather than using the stale one here.
358+
# Even better, set up Renovate and let it do the work for you (see "Suggestion: Updates" in the README).
359+
url = "https://github.com/hedronvision/bazel-compile-commands-extractor/archive/4f28899228fb3ad0126897876f147ca15026151e.tar.gz",
360+
# When you first run this tool, it'll recommend a sha256 hash to put here with a message like: "DEBUG: Rule 'hedron_compile_commands' indicated that a canonical reproducible form can be obtained by modifying arguments sha256 = ..."
361+
)
362+
363+
load("@hedron_compile_commands//:workspace_setup.bzl", "hedron_compile_commands_setup")
364+
365+
hedron_compile_commands_setup()
366+
367+
load("@hedron_compile_commands//:workspace_setup_transitive.bzl", "hedron_compile_commands_setup_transitive")
368+
369+
hedron_compile_commands_setup_transitive()
370+
371+
load("@hedron_compile_commands//:workspace_setup_transitive_transitive.bzl", "hedron_compile_commands_setup_transitive_transitive")
372+
373+
hedron_compile_commands_setup_transitive_transitive()
374+
375+
load("@hedron_compile_commands//:workspace_setup_transitive_transitive_transitive.bzl", "hedron_compile_commands_setup_transitive_transitive_transitive")
376+
377+
hedron_compile_commands_setup_transitive_transitive_transitive()
378+

src/Compiler.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,6 +1010,13 @@ function optimization_passes(
10101010
"const_prop_through_barrier<16>",
10111011
"concat_const_prop<1>($max_constant_threshold)",
10121012
"dynamic_update_slice_const_prop($max_constant_threshold)",
1013+
"widen_wrap",
1014+
"widen_extend",
1015+
"elementwise_pad",
1016+
"compare_negate_const_simplify",
1017+
"select_simplify",
1018+
"concatenate_subtract_to_subtract_pad",
1019+
"concatenate_broadcast_in_dim"
10131020
],
10141021
)
10151022

0 commit comments

Comments
 (0)