Skip to content

Commit eab0de4

Browse files
authored
[cute dsl] optimize cute dsl make_ptr perf (#1607)
1 parent bbf4035 commit eab0de4

File tree

2 files changed

+159
-4
lines changed

2 files changed

+159
-4
lines changed

flashinfer/cute_dsl/blockscaled_gemm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@
4040
import functools
4141
from cutlass._mlir import ir
4242
from cutlass.cute.nvgpu import cpasync, tcgen05
43-
from cutlass.cute.runtime import from_dlpack, make_ptr
43+
from cutlass.cute.runtime import from_dlpack
44+
4445
from cutlass.cutlass_dsl import (
4546
Int32,
4647
Integer,
@@ -49,7 +50,7 @@
4950
new_from_mlir_values,
5051
)
5152
from cutlass.utils.static_persistent_tile_scheduler import WorkTileInfo
52-
from .utils import get_cutlass_dtype, cutlass_to_torch_dtype, get_num_sm
53+
from .utils import get_cutlass_dtype, cutlass_to_torch_dtype, get_num_sm, make_ptr
5354
from typing import Callable, List
5455

5556

flashinfer/cute_dsl/utils.py

Lines changed: 156 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,16 @@
1414
limitations under the License.
1515
"""
1616

17+
import ctypes
18+
import functools
19+
import importlib.util
20+
from typing import Union
21+
1722
import cutlass
23+
import cutlass._mlir.dialects.cute as _cute_ir
1824
import torch
19-
import importlib.util
20-
import functools
25+
from cutlass._mlir import ir
26+
from cutlass.cute.typing import AddressSpace, Numeric, Pointer, Type
2127

2228

2329
def is_cute_dsl_available() -> bool:
@@ -67,3 +73,151 @@ def cutlass_to_torch_dtype(cutlass_dtype):
6773
def get_num_sm(device: torch.device) -> int:
6874
# get the compute capability of the device, which would be cached
6975
return torch.cuda.get_device_properties(device).multi_processor_count
76+
77+
78+
# WAR for CuTeDSL make_ptr implementation for flashinfer
79+
class _Pointer(Pointer):
80+
"""Runtime representation of a pointer that can inter-operate with
81+
various data structures, including numpy arrays and device memory.
82+
83+
:param pointer: The pointer to the data
84+
:type pointer: int or pointer-like object
85+
:param dtype: Data type of the elements pointed to
86+
:type dtype: Type
87+
:param mem_space: Memory space where the pointer resides, defaults generic
88+
:type mem_space: _cute_ir.AddressSpace, optional
89+
:param assumed_align: Alignment of input pointer in bytes, defaults None
90+
:type assumed_align: int, optional
91+
92+
:ivar _pointer: The underlying pointer
93+
:ivar _dtype: Data type of the elements
94+
:ivar _addr_space: Memory space of the pointer
95+
:ivar _assumed_align: Alignment of the pointer in bytes
96+
:ivar _desc: C-type descriptor for the pointer
97+
:ivar _c_pointer: C-compatible pointer representation
98+
"""
99+
100+
def __init__(
101+
self,
102+
pointer,
103+
dtype,
104+
mem_space: _cute_ir.AddressSpace = _cute_ir.AddressSpace.generic,
105+
assumed_align=None,
106+
):
107+
self._pointer = pointer
108+
self._dtype = dtype
109+
self._addr_space = mem_space
110+
111+
if assumed_align is None:
112+
self._assumed_align = dtype.width // 8
113+
else:
114+
self._assumed_align = assumed_align
115+
116+
self._desc = None
117+
self._c_pointer = None
118+
assert int(self._pointer) % self._assumed_align == 0, (
119+
f"pointer must be {self._assumed_align} bytes aligned"
120+
)
121+
122+
def size_in_bytes(self) -> int:
123+
return ctypes.sizeof(ctypes.c_void_p(int(self._pointer)))
124+
125+
def __get_mlir_types__(self):
126+
return [self.mlir_type]
127+
128+
def __c_pointers__(self):
129+
if self._c_pointer is None:
130+
self._desc = ctypes.c_void_p(int(self._pointer))
131+
self._c_pointer = ctypes.addressof(self._desc)
132+
return [self._c_pointer]
133+
134+
def __new_from_mlir_values__(self, values):
135+
assert len(values) == 1
136+
return values[0]
137+
138+
# Move mlir Type out of __init__ to decouple with mlir Context
139+
@property
140+
def mlir_type(self) -> ir.Type:
141+
return _cute_ir.PtrType.get(
142+
self._dtype.mlir_type, self._addr_space, self._assumed_align
143+
)
144+
145+
@property
146+
def dtype(self) -> Type[Numeric]:
147+
return self._dtype
148+
149+
@property
150+
def memspace(self):
151+
return self._addr_space
152+
153+
def align(self, min_align: int, *, loc=None, ip=None) -> Pointer:
154+
raise NotImplementedError("align is not supported in runtime")
155+
156+
def verify(self, expected_py_type):
157+
# if expected_py_type is Pointer:
158+
# return True
159+
# elif isinstance(expected_py_type, ir.Value) and expected_py_type.ty is Pointer:
160+
# return True
161+
if expected_py_type is Pointer or (
162+
isinstance(expected_py_type, ir.Value) and expected_py_type.ty is Pointer
163+
):
164+
return True
165+
166+
return False
167+
168+
def __str__(self) -> str:
169+
return f"Ptr<0x{int(self._pointer):016x}@{self._addr_space}>"
170+
171+
def __repr__(self):
172+
return self.__str__()
173+
174+
175+
def make_ptr(
176+
dtype: Type[Numeric],
177+
value: Union[int, ctypes._Pointer],
178+
mem_space: AddressSpace = AddressSpace.generic,
179+
assumed_align=None,
180+
) -> Pointer:
181+
"""Create a pointer from a memory address
182+
183+
:param dtype: Data type of the pointer elements
184+
:type dtype: Type[Numeric]
185+
:param value: Memory address as integer or ctypes pointer
186+
:type value: Union[int, ctypes._Pointer]
187+
:param mem_space: Memory address space, defaults to AddressSpace.generic
188+
:type mem_space: AddressSpace, optional
189+
:param assumed_align: Alignment in bytes, defaults to None
190+
:type assumed_align: int, optional
191+
:return: A pointer object
192+
:rtype: Pointer
193+
194+
.. code-block:: python
195+
196+
import numpy as np
197+
import ctypes
198+
199+
from cutlass import Float32
200+
from cutlass.cute.runtime import make_ptr
201+
202+
# Create a numpy array
203+
a = np.random.randn(16, 32).astype(np.float32)
204+
205+
# Get pointer address as integer
206+
ptr_address = a.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
207+
208+
# Create pointer from address
209+
y = make_ptr(cutlass.Float32, ptr_address)
210+
"""
211+
# check if value is int or ctypes.POINTER
212+
if isinstance(value, int):
213+
address_value = value
214+
elif isinstance(value, ctypes._Pointer):
215+
# get address value
216+
address_value = ctypes.cast(value, ctypes.c_void_p).value
217+
assert address_value is not None, "Pointer address is None"
218+
else:
219+
raise TypeError(
220+
f"Expect int or ctypes.POINTER for value but got {type(value)=}"
221+
)
222+
223+
return _Pointer(address_value, dtype, mem_space, assumed_align=assumed_align)

0 commit comments

Comments
 (0)