-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.mojo
More file actions
28 lines (22 loc) · 874 Bytes
/
main.mojo
File metadata and controls
28 lines (22 loc) · 874 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import triton_lite as tl
from kernels.test import test_kernel
from kernels.make_4d_causal_mask import make_4d_causal_mask_kernel
from kernels.apply_rotary_emb import cos_rotate, sin_rotate
from gpu.host import DeviceContext
from memory import UnsafePointer
from gpu.host._compile import _compile_code, _to_sass
fn main() raises:
    """Compile `make_4d_causal_mask_kernel[128]` and print its assembly text.

    On the live path the kernel is only compiled, never launched; no
    DeviceContext is created and nothing is enqueued on a device.
    """
    # `_compile_code` is an underscore-prefixed (private) stdlib helper, so its
    # interface may change between Mojo releases.
    var compiled_kernel = _compile_code[make_4d_causal_mask_kernel[128]]()

    # NOTE(review): the `.asm` text is presumably PTX on NVIDIA targets
    # (it is what gets fed to `_to_sass` below) -- TODO confirm.
    print(compiled_kernel.asm)

    # Alternative output: presumably converts the assembly to SASS first.
    # print(_to_sass(compiled_kernel.asm))

    # Disabled runtime harness. It launches test_kernel[128] with
    # grid_dim=4, block_dim=128 over a 512-element uint32 device buffer,
    # copies the buffer back to host memory, prints each element, and frees
    # the host allocation.
    #
    # with DeviceContext() as ctx:
    #     dev_ptr = ctx.enqueue_create_buffer[DType.uint32](512)
    #     ctx.enqueue_function[test_kernel[128]](
    #         dev_ptr, 1, 512, grid_dim=4, block_dim=128
    #     )
    #     host_ptr = UnsafePointer[UInt32].alloc(512)
    #     ctx.enqueue_copy(host_ptr, dev_ptr)
    #     ctx.synchronize()
    #     for i in range(512):
    #         print(host_ptr[i])
    #     host_ptr.free()