Skip to content

Commit 361b4a8

Browse files
committed
issue/931 - ninetoothed swiglu
1 parent 6f8a443 commit 361b4a8

File tree

4 files changed

+189
-0
lines changed

4 files changed

+189
-0
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import ninetoothed
2+
from . import swiglu
3+
4+
import infiniop.ninetoothed.build
5+
6+
7+
def build():
8+
MAX_NDIM = 5
9+
10+
ndim_values = range(1, MAX_NDIM + 1)
11+
dtype_values = (
12+
ninetoothed.float16,
13+
ninetoothed.bfloat16,
14+
ninetoothed.float32,
15+
)
16+
17+
constexpr_param_grid = {
18+
"ndim": ndim_values,
19+
"dtype": dtype_values,
20+
"block_size": (1024,),
21+
}
22+
23+
infiniop.ninetoothed.build.build(
24+
swiglu.premake,
25+
constexpr_param_grid,
26+
caller="cuda",
27+
op_name="swiglu",
28+
output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH,
29+
)
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#ifndef SWIGLU_H
2+
#define SWIGLU_H
3+
4+
#include "../../../handle.h"
5+
#include "../../../operator.h"
6+
#include "../../../tensor.h"
7+
8+
#include "../../../../../build/ninetoothed/swiglu.h"
9+
#include "../../../ninetoothed/utils.h"
10+
11+
namespace op::swiglu::ninetoothed {
12+
class Descriptor final : public InfiniopDescriptor {
13+
14+
public:
15+
Descriptor(
16+
infiniopHandle_t handle,
17+
infiniopTensorDescriptor_t out_desc,
18+
std::vector<infiniopTensorDescriptor_t> input_desc_vec) : InfiniopDescriptor{handle->device, handle->device_id},
19+
out_shape_{out_desc->shape()},
20+
out_strides_{out_desc->strides()},
21+
up_shape_{input_desc_vec[0]->shape()},
22+
up_strides_{input_desc_vec[0]->strides()},
23+
gate_shape_{input_desc_vec[1]->shape()},
24+
gate_strides_{input_desc_vec[1]->strides()},
25+
dtype_{out_desc->dtype()} {}
26+
27+
~Descriptor() = default;
28+
29+
size_t workspaceSize() const {
30+
return 0;
31+
}
32+
33+
static infiniStatus_t create(
34+
infiniopHandle_t handle,
35+
Descriptor **desc_ptr,
36+
infiniopTensorDescriptor_t out_desc,
37+
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
38+
*desc_ptr = new Descriptor(handle, out_desc, input_desc_vec);
39+
return INFINI_STATUS_SUCCESS;
40+
}
41+
42+
infiniStatus_t calculate(
43+
void *workspace,
44+
size_t workspace_size,
45+
void *output,
46+
std::vector<const void *> inputs,
47+
void *stream) const {
48+
auto out_nt{::ninetoothed::Tensor(output, out_shape_, out_strides_)};
49+
auto up_nt{::ninetoothed::Tensor(inputs[0], up_shape_, up_strides_)};
50+
auto gate_nt{::ninetoothed::Tensor(inputs[1], gate_shape_, gate_strides_)};
51+
52+
if (launch_swiglu(stream,
53+
out_nt,
54+
up_nt,
55+
gate_nt,
56+
out_shape_.size(),
57+
dtype_,
58+
1024)) {
59+
return INFINI_STATUS_NOT_IMPLEMENTED;
60+
}
61+
62+
return INFINI_STATUS_SUCCESS;
63+
}
64+
65+
private:
66+
using Size = ::ninetoothed::Tensor<>::Size;
67+
using Stride = ::ninetoothed::Tensor<>::Stride;
68+
69+
std::vector<Size> out_shape_;
70+
std::vector<Stride> out_strides_;
71+
72+
std::vector<Size> up_shape_;
73+
std::vector<Stride> up_strides_;
74+
75+
std::vector<Size> gate_shape_;
76+
std::vector<Stride> gate_strides_;
77+
78+
infiniDtype_t dtype_;
79+
};
80+
} // namespace op::swiglu::ninetoothed
81+
82+
#endif // SWIGLU_H
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import functools
2+
3+
import ninetoothed.language as ntl
4+
from ninetoothed import Tensor
5+
6+
from ntops.kernels.element_wise import arrangement
7+
8+
9+
def application(output, up, gate):
10+
output = ntl.sigmoid(ntl.cast(gate, ntl.float32)) * gate * up # noqa: F841
11+
12+
13+
def premake(ndim, dtype=None, block_size=None):
14+
arrangement_ = functools.partial(arrangement, block_size=block_size)
15+
16+
tensors = (
17+
Tensor(ndim, dtype=dtype),
18+
Tensor(ndim, dtype=dtype),
19+
Tensor(ndim, dtype=dtype),
20+
)
21+
22+
return arrangement_, application, tensors

src/infiniop/ops/swiglu/operator.cc

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,22 @@
66
#include "cpu/swiglu_cpu.h"
77
#endif
88
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
9+
#if defined(ENABLE_NINETOOTHED)
10+
#include "ninetoothed/swiglu.h"
11+
#else
912
#include "nvidia/swiglu_nvidia.cuh"
1013
#endif
14+
#endif
1115
#ifdef ENABLE_KUNLUN_API
1216
#include "kunlun/swiglu_kunlun.h"
1317
#endif
1418
#ifdef ENABLE_METAX_API
19+
#if defined(ENABLE_NINETOOTHED)
20+
#include "ninetoothed/swiglu.h"
21+
#else
1522
#include "metax/swiglu_metax.h"
1623
#endif
24+
#endif
1725
#ifdef ENABLE_CAMBRICON_API
1826
#include "bang/swiglu_bang.h"
1927
#endif
@@ -46,11 +54,19 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
4654
CREATE(INFINI_DEVICE_CPU, cpu);
4755
#endif
4856
#ifdef ENABLE_NVIDIA_API
57+
#ifdef ENABLE_NINETOOTHED
58+
CREATE(INFINI_DEVICE_NVIDIA, ninetoothed);
59+
#else
4960
CREATE(INFINI_DEVICE_NVIDIA, nvidia);
5061
#endif
62+
#endif
5163
#ifdef ENABLE_ILUVATAR_API
64+
#ifdef ENABLE_NINETOOTHED
65+
CREATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
66+
#else
5267
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
5368
#endif
69+
#endif
5470
#ifdef ENABLE_QY_API
5571
CREATE(INFINI_DEVICE_QY, nvidia);
5672
#endif
@@ -61,8 +77,12 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
6177
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
6278
#endif
6379
#ifdef ENABLE_METAX_API
80+
#ifdef ENABLE_NINETOOTHED
81+
CREATE(INFINI_DEVICE_METAX, ninetoothed);
82+
#else
6483
CREATE(INFINI_DEVICE_METAX, metax);
6584
#endif
85+
#endif
6686
#ifdef ENABLE_CAMBRICON_API
6787
CREATE(INFINI_DEVICE_CAMBRICON, bang);
6888
#endif
@@ -92,11 +112,19 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des
92112
GET(INFINI_DEVICE_CPU, cpu);
93113
#endif
94114
#ifdef ENABLE_NVIDIA_API
115+
#ifdef ENABLE_NINETOOTHED
116+
GET(INFINI_DEVICE_NVIDIA, ninetoothed);
117+
#else
95118
GET(INFINI_DEVICE_NVIDIA, nvidia);
96119
#endif
120+
#endif
97121
#ifdef ENABLE_ILUVATAR_API
122+
#ifdef ENABLE_NINETOOTHED
123+
GET(INFINI_DEVICE_ILUVATAR, ninetoothed);
124+
#else
98125
GET(INFINI_DEVICE_ILUVATAR, nvidia);
99126
#endif
127+
#endif
100128
#ifdef ENABLE_QY_API
101129
GET(INFINI_DEVICE_QY, nvidia);
102130
#endif
@@ -107,8 +135,12 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des
107135
GET(INFINI_DEVICE_KUNLUN, kunlun);
108136
#endif
109137
#ifdef ENABLE_METAX_API
138+
#ifdef ENABLE_NINETOOTHED
139+
GET(INFINI_DEVICE_METAX, ninetoothed);
140+
#else
110141
GET(INFINI_DEVICE_METAX, metax);
111142
#endif
143+
#endif
112144
#ifdef ENABLE_CAMBRICON_API
113145
GET(INFINI_DEVICE_CAMBRICON, bang);
114146
#endif
@@ -145,11 +177,19 @@ __C infiniStatus_t infiniopSwiGLU(
145177
CALCULATE(INFINI_DEVICE_CPU, cpu);
146178
#endif
147179
#ifdef ENABLE_NVIDIA_API
180+
#ifdef ENABLE_NINETOOTHED
181+
CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed);
182+
#else
148183
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
149184
#endif
185+
#endif
150186
#ifdef ENABLE_ILUVATAR_API
187+
#ifdef ENABLE_NINETOOTHED
188+
CALCULATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
189+
#else
151190
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
152191
#endif
192+
#endif
153193
#ifdef ENABLE_QY_API
154194
CALCULATE(INFINI_DEVICE_QY, nvidia);
155195
#endif
@@ -160,8 +200,12 @@ __C infiniStatus_t infiniopSwiGLU(
160200
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
161201
#endif
162202
#ifdef ENABLE_METAX_API
203+
#ifdef ENABLE_NINETOOTHED
204+
CALCULATE(INFINI_DEVICE_METAX, ninetoothed);
205+
#else
163206
CALCULATE(INFINI_DEVICE_METAX, metax);
164207
#endif
208+
#endif
165209
#ifdef ENABLE_CAMBRICON_API
166210
CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
167211
#endif
@@ -193,11 +237,19 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
193237
DELETE(INFINI_DEVICE_CPU, cpu);
194238
#endif
195239
#ifdef ENABLE_NVIDIA_API
240+
#ifdef ENABLE_NINETOOTHED
241+
DELETE(INFINI_DEVICE_NVIDIA, ninetoothed);
242+
#else
196243
DELETE(INFINI_DEVICE_NVIDIA, nvidia);
197244
#endif
245+
#endif
198246
#ifdef ENABLE_ILUVATAR_API
247+
#ifdef ENABLE_NINETOOTHED
248+
DELETE(INFINI_DEVICE_ILUVATAR, ninetoothed);
249+
#else
199250
DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
200251
#endif
252+
#endif
201253
#ifdef ENABLE_QY_API
202254
DELETE(INFINI_DEVICE_QY, nvidia);
203255
#endif
@@ -208,8 +260,12 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
208260
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
209261
#endif
210262
#ifdef ENABLE_METAX_API
263+
#ifdef ENABLE_NINETOOTHED
264+
DELETE(INFINI_DEVICE_METAX, ninetoothed);
265+
#else
211266
DELETE(INFINI_DEVICE_METAX, metax);
212267
#endif
268+
#endif
213269
#ifdef ENABLE_CAMBRICON_API
214270
DELETE(INFINI_DEVICE_CAMBRICON, bang);
215271
#endif

0 commit comments

Comments
 (0)