33from algorithm import vectorize, parallelize
44from memory.memory import _malloc, stack_allocation
55from sys import CompilationTarget, num_performance_cores, simd_width_of, size_of
6- import benchmark
7- from testing import assert_equal
86from utils import IndexList
97import random
108from .params import *
@@ -37,11 +35,11 @@ struct Layout(Copyable, Movable, Writable):
3735 var shape : IndexList[2 ]
3836 var strides : IndexList[2 ]
3937
40- fn __init__ (out self , shape : ( Int, Int) , strides : ( Int, Int) ):
38+ fn __init__ (out self , shape : Tuple[ Int, Int] , strides : Tuple[ Int, Int] ):
4139 self .shape = IndexList[2 ](shape[0 ], shape[1 ])
4240 self .strides = IndexList[2 ](strides[0 ], strides[1 ])
4341
44- fn __init__ (out self , shape : ( Int, Int) ):
42+ fn __init__ (out self , shape : Tuple[ Int, Int] ):
4543 self .strides = IndexList[2 ](shape[1 ], 1 )
4644 self .shape = IndexList[2 ](shape[0 ], shape[1 ])
4745
@@ -59,31 +57,31 @@ struct Layout(Copyable, Movable, Writable):
5957
6058
6159struct Matrix[Type: DType]:
62- var data : UnsafePointer[Scalar[Type]]
60+ var data : UnsafePointer[Scalar[Type], MutAnyOrigin ]
6361 var layout : Layout
6462
65- fn __init__ (out self , shape : ( Int, Int) ):
66- self .data = UnsafePointer [Scalar[Type]].alloc (shape[0 ] * shape[1 ])
63+ fn __init__ (out self , shape : Tuple[ Int, Int] ):
64+ self .data = alloc [Scalar[Type]](shape[0 ] * shape[1 ])
6765 self .layout = Layout(shape)
6866
6967 @always_inline (" nodebug" )
7068 fn __init__ (
71- out self , data : UnsafePointer[Scalar[Type]], var layout : Layout
69+ out self , data : UnsafePointer[Scalar[Type], MutAnyOrigin ], var layout : Layout
7270 ):
73- self .data = UnsafePointer[Scalar[Type]]( data)
71+ self .data = data
7472 self .layout = layout
7573
7674 @always_inline (" nodebug" )
7775 fn __init__ (
78- out self , data : UnsafePointer[Scalar[Type]], shape : ( Int, Int)
76+ out self , data : UnsafePointer[Scalar[Type], MutAnyOrigin ], shape : Tuple[ Int, Int]
7977 ):
8078 self .data = data
8179 self .layout = Layout(shape)
8280
8381 @always_inline (" nodebug" )
8482 fn __getitem__ (
8583 ref [_]self , i : Int, j : Int
86- ) -> ref [__origin_of (self )] Scalar[Type]:
84+ ) -> ref [origin_of (self )] Scalar[Type]:
8785 var offset = self .layout(i, j)
8886 return (self .data + offset)[]
8987
@@ -146,7 +144,7 @@ struct Matrix[Type: DType]:
146144@always_inline
147145fn pack_A [
148146 Type : DType, //, mr : Int
149- ](mc : Int, Ac_buffer : UnsafePointer[Scalar[Type]], Ac : Matrix[Type]) -> Matrix[Type]:
147+ ](mc : Int, Ac_buffer : UnsafePointer[Scalar[Type], MutAnyOrigin ], Ac : Matrix[Type]) -> Matrix[Type]:
150148 @parameter
151149 fn pack_panel (idx : Int):
152150 var i = idx * mr
@@ -184,7 +182,7 @@ fn pack_A[
184182@always_inline
185183fn pack_B [
186184 Type : DType, //, kc : Int, nr : Int
187- ](Bc_buffer : UnsafePointer[Scalar[Type]], Bc : Matrix[Type]) -> Matrix[Type]:
185+ ](Bc_buffer : UnsafePointer[Scalar[Type], MutAnyOrigin ], Bc : Matrix[Type]) -> Matrix[Type]:
188186 var dst_ptr = Bc_buffer
189187 for i in range (0 , Bc.shape[1 ](), nr):
190188 var src_ptr = Bc.data + i
@@ -267,7 +265,7 @@ fn loop_n[
267265
268266 @parameter
269267 fn parallelize_balanced_part (idx : Int):
270- var Bc_buffer = UnsafePointer[Scalar[Type]](
268+ var Bc_buffer = UnsafePointer[Scalar[Type], MutAnyOrigin ](
271269 _malloc[Scalar[Type]](
272270 kc * nc_per_thread * size_of[Type](), alignment = 64
273271 )
@@ -290,7 +288,7 @@ fn loop_n[
290288
291289 @parameter
292290 fn parallelize_remainder (idx : Int):
293- var Bc_buffer = UnsafePointer[Scalar[Type]](
291+ var Bc_buffer = UnsafePointer[Scalar[Type], MutAnyOrigin ](
294292 _malloc[Scalar[Type]](
295293 kc * remainder_per_thread * size_of[Type](), alignment = 64
296294 )
@@ -348,7 +346,7 @@ fn macro_kernel[
348346fn micro_kernel [
349347 Type : DType, //, mr : Int, nr : Int, padding : Bool
350348](mut Cr : Matrix[Type], Ar : Matrix[Type], Br : Matrix[Type]):
351- alias simd_width = simd_width_of[Type]()
349+ comptime simd_width = simd_width_of[Type]()
352350 constrained[nr % simd_width == 0 , " nr must be multiple of simd_width" ]()
353351
354352 var Ar_ptr = Ar.data
@@ -440,31 +438,31 @@ fn micro_kernel[
440438
441439@always_inline
442440fn matmul_params [Type : DType]() -> IndexList[5 ]:
443- alias mc = 8192 // size_of[Type]() # fix this for simplicity
444- alias N = simd_width_of[Type]()
441+ comptime mc = 8192 // size_of[Type]() # fix this for simplicity
442+ comptime N = simd_width_of[Type]()
445443
446- alias Vectors = 32 if CompilationTarget.has_avx512f() else 16
444+ comptime Vectors = 32 if CompilationTarget.has_avx512f() else 16
447445
448446 @parameter
449447 fn compute_kc [mr : Int, nr : Int]() -> Int:
450- alias CBr = Int((L1_ASSOCIATIVITY - 1 ) / (1 + mr / nr))
448+ comptime CBr = Int((L1_ASSOCIATIVITY - 1 ) / (1 + mr / nr))
451449 return (CBr * L1_CACHE_SIZE ) // (nr * size_of[Type]() * L1_ASSOCIATIVITY )
452450
453451 @parameter
454452 fn compute_params [C : Int]() -> IndexList[5 ]:
455- alias p = C // (intsqrt[C]() + 1 )
456- alias mr = C // p - 1
457- alias nr = p * N
458- alias CBr = Int((L1_ASSOCIATIVITY - 1 ) / (1 + mr / nr))
459- alias kc = compute_kc[mr, nr]()
460- alias nc = (L2_ASSOCIATIVITY - 1 ) * L2_CACHE_SIZE // (
453+ comptime p = C // (intsqrt[C]() + 1 )
454+ comptime mr = C // p - 1
455+ comptime nr = p * N
456+ comptime CBr = Int((L1_ASSOCIATIVITY - 1 ) / (1 + mr / nr))
457+ comptime kc = compute_kc[mr, nr]()
458+ comptime nc = (L2_ASSOCIATIVITY - 1 ) * L2_CACHE_SIZE // (
461459 kc * size_of[Type]() * L2_ASSOCIATIVITY
462460 ) - mr
463461 return IndexList[5 ](mc, nc, kc, mr, nr)
464462
465463 @parameter
466464 if Type.is_floating_point():
467- alias TempVectors = 1
465+ comptime TempVectors = 1
468466 return compute_params[Vectors - TempVectors]()
469467 else :
470468
@@ -473,25 +471,25 @@ fn matmul_params[Type: DType]() -> IndexList[5]:
473471
474472 @parameter
475473 if CompilationTarget.has_avx512f():
476- alias TempVectors = 2
474+ comptime TempVectors = 2
477475 return compute_params[Vectors - TempVectors]()
478476 else :
479- alias TempVectors = 3
477+ comptime TempVectors = 3
480478 return compute_params[Vectors - TempVectors]()
481479 else :
482- alias TempVectors = 2
480+ comptime TempVectors = 2
483481 return compute_params[Vectors - TempVectors]()
484482
485483
486484fn matmul [
487485 Type : DType
488486](m : Int, n : Int, k : Int, mut C : Matrix[Type], A : Matrix[Type], B : Matrix[Type]):
489- alias params = matmul_params[Type]()
490- alias mc = params[0 ]
491- alias nc = params[1 ]
492- alias kc = params[2 ]
493- alias mr = params[3 ]
494- alias nr = params[4 ]
487+ comptime params = matmul_params[Type]()
488+ comptime mc = params[0 ]
489+ comptime nc = params[1 ]
490+ comptime kc = params[2 ]
491+ comptime mr = params[3 ]
492+ comptime nr = params[4 ]
495493 var resized_mc = roundup(min (mc, m), mr)
496494 var resized_nc = roundup(min (nc, n), nr)
497495 matmul_impl[kc, mr, nr](resized_mc, resized_nc, C, A, B)
0 commit comments