|
32 | 32 | assert os.path.isfile(dep_versions_path) |
33 | 33 | assert os.path.isfile(catalyst_init_path) |
34 | 34 |
|
35 | | -url = f"https://raw.githubusercontent.com/google/jax/jaxlib-v{jax_version}/WORKSPACE" |
| 35 | +url = f"https://raw.githubusercontent.com/jax-ml/jax/jax-v{jax_version}/WORKSPACE" |
36 | 36 | response = requests.get(url) |
37 | 37 | match = re.search(r'strip_prefix = "xla-([a-zA-Z0-9]*)"', response.text) |
38 | 38 | if not match: |
39 | | - url = f"https://raw.githubusercontent.com/google/jax/jaxlib-v{jax_version}/third_party/xla/workspace.bzl" |
| 39 | + url = f"https://raw.githubusercontent.com/jax-ml/jax/jax-v{jax_version}/third_party/xla/workspace.bzl" |
40 | 40 | response = requests.get(url) |
41 | 41 | match = re.search(r'XLA_COMMIT = "([a-zA-Z0-9]*)"', response.text) |
42 | 42 | xla_commit = match.group(1) |
|
67 | 67 | response = requests.get(url).json() |
68 | 68 | hlo_commit = response["items"][0]["sha"] |
69 | 69 |
|
70 | | -existing_text = open(dep_versions_path, "r", encoding="UTF-8").read() |
71 | | -match = re.search(r"enzyme=([a-zA-Z0-9]*)", existing_text) |
72 | | -enzyme_commit = match.group(1) |
73 | | - |
74 | | -with open(dep_versions_path, "w", encoding="UTF-8") as f: |
75 | | - f.write( |
76 | | - f"""\ |
77 | | -jax={jax_version} |
78 | | -mhlo={hlo_commit} |
79 | | -llvm={llvm_commit} |
80 | | -enzyme={enzyme_commit} |
81 | | -""" |
82 | | - ) |
83 | | - |
84 | 70 | quote = '"' |
85 | | -cmd = f"sed -i 's/_jaxlib_version = {quote}\([0-9.]\+\){quote}/_jaxlib_version = {quote}{jax_version}{quote}/g' {catalyst_init_path}" |
86 | | -res = os.system(cmd) |
87 | | -assert res == 0 |
| 71 | +# Update each version using sed |
| 72 | +cmds = [ |
| 73 | + f"sed -i '' 's/^jax=.*/jax={jax_version}/' {dep_versions_path}", |
| 74 | + f"sed -i '' 's/^mhlo=.*/mhlo={hlo_commit}/' {dep_versions_path}", |
| 75 | + f"sed -i '' 's/^llvm=.*/llvm={llvm_commit}/' {dep_versions_path}", |
| 76 | + # Update jaxlib version in __init__.py |
| 77 | + rf"sed -i '' 's/_jaxlib_version = {quote}\([0-9.]\+\){quote}/_jaxlib_version = {quote}{jax_version}{quote}/g' {catalyst_init_path}", |
| 78 | +] |
| 79 | + |
| 80 | +for cmd in cmds: |
| 81 | + res = os.system(cmd) |
| 82 | + assert res == 0 |
0 commit comments