@@ -114,19 +114,23 @@ class Tensor(NamedTuple):
114114
115115
116116@triton .jit
117- def _namedtuple_kernel ( closure , X , Y , BLOCK_M : tl .constexpr , BLOCK_N : tl .constexpr ):
117+ def _namedtuple_mask_func ( Tensor , BLOCK_M : tl .constexpr , BLOCK_N : tl .constexpr ):
118118 offs_m = tl .arange (0 , BLOCK_M )
119119 offs_n = tl .arange (0 , BLOCK_N )
120- # load x
121- mask_x = (offs_m [:, None ] < X .shape [0 ]) & (offs_n [None , :] < X .shape [1 ])
120+ mask = (offs_m [:, None ] < Tensor .shape [0 ]) & (offs_n [None , :] < Tensor .shape [1 ])
121+ return mask
122+
123+
124+ @triton .jit
125+ def _namedtuple_kernel (closure , _X , Y , BLOCK_M : tl .constexpr , BLOCK_N : tl .constexpr ):
126+ offs_m = tl .arange (0 , BLOCK_M )
127+ offs_n = tl .arange (0 , BLOCK_N )
128+ X = Tensor (shape = _X .shape , ptr = _X .ptr , stride = _X .stride )
122129 Xs = X .ptr + offs_m [:, None ] * X .stride [0 ] + offs_n [None , :] * X .stride [1 ]
123- x = tl .load (Xs , mask = mask_x , other = 0 )
124- # compute y
125- y = closure .fn (x , * closure .captured )
126- # store y
127- mask_y = (offs_m [:, None ] < Y .shape [0 ]) & (offs_n [None , :] < Y .shape [1 ])
128130 Ys = Y .ptr + offs_m [:, None ] * Y .stride [0 ] + offs_n [None , :] * Y .stride [1 ]
129- tl .store (Ys , y , mask = mask_y )
131+ x = tl .load (Xs , mask = _namedtuple_mask_func (X , BLOCK_M , BLOCK_N ), other = 0 )
132+ y = closure .fn (x , * closure .captured )
133+ tl .store (Ys , y , mask = _namedtuple_mask_func (Y , BLOCK_M , BLOCK_N ))
130134
131135
132136def test_namedtuple (device = "cuda" ):
0 commit comments