Skip to content

Commit 623e032

Browse files
authored
fix transfer_read/write (#152)
1 parent 550fe4b commit 623e032

File tree

5 files changed

+76
-61
lines changed

5 files changed

+76
-61
lines changed

examples/cuda_e2e.ipynb

Lines changed: 4 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -29,65 +29,13 @@
2929
"cell_type": "markdown",
3030
"source": "# Download mlir-python-bindings with CUDA support"
3131
},
32-
{
33-
"metadata": {
34-
"ExecuteTime": {
35-
"end_time": "2025-05-05T16:40:48.337497Z",
36-
"start_time": "2025-05-05T16:40:48.251771Z"
37-
}
38-
},
39-
"cell_type": "code",
40-
"source": [
41-
"from pip._internal.cli import cmdoptions\n",
42-
"from pip._internal.commands import index\n",
43-
"\n",
44-
"cmd = index.IndexCommand(\"blah\", \"\")\n",
45-
"\n",
46-
"options, args = cmd.parse_args(\n",
47-
" [\n",
48-
" \"index\",\n",
49-
" \"versions\",\n",
50-
" \"mlir-python-bindings\",\n",
51-
" \"-f\",\n",
52-
" \"https://makslevental.github.io/wheels\",\n",
53-
" ]\n",
54-
")\n",
55-
"\n",
56-
"\n",
57-
"def get_available_package_versions(self, options, args):\n",
58-
" target_python = cmdoptions.make_target_python(options)\n",
59-
" query = args[0]\n",
60-
"\n",
61-
" with self._build_session(options) as session:\n",
62-
" finder = self._build_package_finder(\n",
63-
" options=options,\n",
64-
" session=session,\n",
65-
" target_python=target_python,\n",
66-
" ignore_requires_python=options.ignore_requires_python,\n",
67-
" )\n",
68-
"\n",
69-
" versions = set(\n",
70-
" candidate.version for candidate in finder.find_all_candidates(query)\n",
71-
" )\n",
72-
"\n",
73-
" return list(versions)\n",
74-
"\n",
75-
"\n",
76-
"def get_latest_cuda_version(all_versions):\n",
77-
" cuda_versions = list(filter(lambda x: \"cuda\" in x.local, all_versions))\n",
78-
" assert len(cuda_versions), \"couldn't find any cuda versions\"\n",
79-
" cuda_versions.sort(key=lambda x: x.release)\n",
80-
" return cuda_versions[0]\n",
81-
"\n"
82-
],
83-
"outputs": [],
84-
"execution_count": 1
85-
},
8632
{
8733
"cell_type": "code",
8834
"source": [
89-
"all_versions = get_available_package_versions(cmd, options, [\"mlir-python-bindings\"])\n",
90-
"latest_cuda_version = get_latest_cuda_version(all_versions)\n",
35+
"!BRANCH=\"${BRANCH:-main}\"\n",
36+
"!echo \"using BRANCH=$BRANCH\"\n",
37+
"!script_address=\"https://raw.githubusercontent.com/makslevental/mlir-python-extras/refs/heads/$BRANCH/scripts/get_latest_gpu_bindings.py\"\n",
38+
"!latest_cuda_version=$(curl $script_address | python -)\n",
9139
"!pip install -q mlir_python_bindings==$latest_cuda_version -f https://makslevental.github.io/wheels\n",
9240
"!pip install git+https://github.com/makslevental/mlir-python-extras@$BRANCH &> /dev/null"
9341
],

examples/mlir_python_extras.ipynb

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
},
2424
"outputs": [],
2525
"source": [
26+
"!BRANCH=\"${BRANCH:-main}\"\n",
27+
"!echo \"using BRANCH=$BRANCH\"\n",
2628
"!pip install mlir-python-bindings -f https://makslevental.github.io/wheels &> /dev/null\n",
2729
"!pip install git+https://github.com/makslevental/mlir-python-extras@$BRANCH &> /dev/null"
2830
]

examples/vectorization_e2e.ipynb

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
},
3232
"outputs": [],
3333
"source": [
34+
"!BRANCH=\"${BRANCH:-main}\"\n",
35+
"!echo \"using BRANCH=$BRANCH\"\n",
3436
"!pip install -q mlir-python-bindings -f https://makslevental.github.io/wheels\n",
3537
"!pip install -q git+https://github.com/makslevental/mlir-python-extras@$BRANCH"
3638
]

mlir/extras/dialects/ext/vector.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def transfer_write(
100100
result=None,
101101
value_to_store=val,
102102
# no clue why they chose this name...
103-
source=dest,
103+
base=dest,
104104
indices=indices,
105105
permutation_map=permutation_map,
106106
mask=mask,
@@ -115,7 +115,7 @@ def transfer_write(
115115

116116
def transfer_read(
117117
vector_t,
118-
source,
118+
base,
119119
indices,
120120
*,
121121
permutation_map=None,
@@ -128,20 +128,20 @@ def transfer_read(
128128
if loc is None:
129129
loc = get_user_code_loc()
130130
if permutation_map is None:
131-
permutation_map = AffineMap.get_minor_identity(source.type.rank, vector_t.rank)
131+
permutation_map = AffineMap.get_minor_identity(base.type.rank, vector_t.rank)
132132
for j, i in enumerate(indices):
133133
if isinstance(i, int):
134134
indices[j] = constant(i, index=True)
135135
if padding is None:
136136
padding = 0
137137
if isinstance(padding, int):
138-
padding = constant(padding, type=source.type.element_type)
138+
padding = constant(padding, type=base.type.element_type)
139139
if in_bounds is None:
140140
raise ValueError("in_bounds cannot be None")
141141

142142
return _transfer_read(
143143
vector=vector_t,
144-
source=source,
144+
base=base,
145145
indices=indices,
146146
permutation_map=permutation_map,
147147
padding=padding,

scripts/get_latest_gpu_bindings.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import argparse
2+
3+
from pip._internal.cli import cmdoptions
4+
from pip._internal.commands import index
5+
6+
7+
def get_available_package_versions(self, options, args):
8+
target_python = cmdoptions.make_target_python(options)
9+
query = args[0]
10+
11+
with self._build_session(options) as session:
12+
finder = self._build_package_finder(
13+
options=options,
14+
session=session,
15+
target_python=target_python,
16+
ignore_requires_python=options.ignore_requires_python,
17+
)
18+
19+
versions = set(
20+
candidate.version for candidate in finder.find_all_candidates(query)
21+
)
22+
23+
return list(versions)
24+
25+
26+
def get_latest_gpu_version(all_versions, platform: str):
27+
bindings_versions = list(
28+
filter(lambda x: x.local and platform in x.local, all_versions)
29+
)
30+
assert len(bindings_versions), "couldn't find any bindings versions"
31+
bindings_versions.sort(key=lambda x: x.release)
32+
return bindings_versions[0]
33+
34+
35+
cmd = index.IndexCommand("blah", "")
36+
37+
38+
def get_latest_gpu_version_name(platform):
39+
options, _args = cmd.parse_args(
40+
[
41+
"index",
42+
"versions",
43+
"mlir-python-bindings",
44+
"--find-links",
45+
"https://makslevental.github.io/wheels",
46+
]
47+
)
48+
options.no_index = True
49+
all_versions = get_available_package_versions(
50+
cmd, options, ["mlir-python-bindings"]
51+
)
52+
return get_latest_gpu_version(all_versions, platform)
53+
54+
55+
def main():
56+
parser = argparse.ArgumentParser()
57+
parser.add_argument("platform", nargs="?", default="cuda")
58+
platform = parser.parse_args().platform
59+
print(get_latest_gpu_version_name(platform))
60+
61+
62+
if __name__ == "__main__":
63+
main()

0 commit comments

Comments
 (0)