Skip to content

Commit 1c9c4a4

Browse files
committed
draf
1 parent 1e8c121 commit 1c9c4a4

File tree

3 files changed

+171
-0
lines changed

3 files changed

+171
-0
lines changed

operators/cuda/roatry_impl.cuh

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
#include <cuda.h>
6+
#include <cuda_runtime.h>
7+
8+
enum class RotarySide : int {
9+
LEFT = 1,
10+
RIGHT = 2,
11+
};
12+
13+
template <typename T>
14+
cudaError_t LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim,
15+
const T* input, const int64_t* split_data, T* output, RotarySide side);

operators/cuda/rotary.h

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
#include "ocos.h"
6+
#include "rotary_impl.cuh"
7+
#include "ortx_common.h"
8+
9+
namespace contrib {
10+
11+
template <typename T>
12+
struct Rotary {
13+
template <typename TDict>
14+
OrtxStatus OnModelAttach(OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
15+
std::string side;
16+
auto status = OrtW::GetOpAttribute(info, "side", side);
17+
if (!status) {
18+
return {kOrtxErrorInvalidArgument, "Missing or wrong argument side."};
19+
}
20+
if (side == "left") {
21+
side_ = RotarySide::LEFT;
22+
}
23+
else if (side == "right") {
24+
side_ = RotarySide::RIGHT;
25+
}
26+
else {
27+
return {kOrtxErrorInvalidArgument, "side must be 'left' or 'right'."};
28+
}
29+
30+
return {};
31+
}
32+
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
33+
const ortc::Tensor<T>& input,
34+
const ortc::Tensor<int64_t>& split,
35+
ortc::Tensor<T>& output) const {
36+
const T* input_data = input.Data();
37+
auto input_shape = input.Shape();
38+
T* output_data = output.Allocate(input_shape);
39+
auto input_length = input.NumberOfElement();
40+
if (0 == input_length) {
41+
return {};
42+
}
43+
44+
auto shape_split = split.Shape();
45+
if (shape_split.size() != 1 || shape_split[0] != 2) {
46+
return {kOrtxErrorInvalidArgument, "Rotary only works when there are two sides."};
47+
}
48+
if (shape_split[0] != shape_split[1]) {
49+
return {kOrtxErrorInvalidArgument, "Only equal split are allowed."};
50+
}
51+
if (shape_split[0] * 2 != input_shape[input_shape.size()-1]) {
52+
return {kOrtxErrorInvalidArgument, "Sum of the splits are not equal to the last dimension."};
53+
}
54+
55+
const int64_t* split_data = split.Data();
56+
57+
LaunchRotaryKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
58+
input_length,
59+
static_cast<int>(input_shape[input_shape.size()-1]),
60+
input_data,
61+
split_data,
62+
output_data,
63+
side_);
64+
return {};
65+
}
66+
67+
static OrtMemType GetInputMemoryType(size_t input_index) {
68+
if (input_index == 1) // split
69+
return OrtMemType::OrtMemTypeCPUInput;
70+
return OrtMemType::OrtMemTypeDefault;
71+
}
72+
73+
private:
74+
RotarySide side_;
75+
};
76+
77+
} // namespace contrib

operators/cuda/rotary_impl.cu

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "device_prop.cuh"
5+
#include "utils.cuh"
6+
#include "Rotary_impl.cuh"
7+
#include "cuda_type.h"
8+
9+
using namespace Ort::Custom;
10+
11+
template <typename T> __device__ __inline__ T _neg(const T x) { return -x; }
12+
13+
#if __CUDA_ARCH__ < 700
14+
template <> __device__ __inline__ half _neg(const half x) {
15+
return __float2half(-__half2float(x));
16+
}
17+
#endif
18+
19+
template <typename T, RotarySide side>
20+
__global__ void RotaryKernel(T *output_data, const T *input_data, CUDA_LONG half_N, CUDA_LONG half_stride) {
21+
CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
22+
if (id >= half_N)
23+
return;
24+
CUDA_LONG last = id % half_stride;
25+
id = (id - last) * 2 + last;
26+
if (side == RotarySide::RIGHT) {
27+
output_data[id + half_stride] = input_data[id];
28+
output_data[id] = _neg(input_data[id + half_stride]);
29+
} else {
30+
output_data[id + half_stride] = _neg(input_data[id]);
31+
output_data[id] = input_data[id + half_stride];
32+
}
33+
}
34+
35+
template <typename T>
36+
cudaError_t _LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim,
37+
const T* input, const int64_t* split_data, T* output, RotarySide side) {
38+
constexpr int blockSize = 256;
39+
const int gridSize = (input_length + blockSize - 1) / blockSize;
40+
if (input_length == 0)
41+
return;
42+
using TT = typename contrib::CudaT<T>::MappedType;
43+
44+
CUDA_LONG N = static_cast<CUDA_LONG>(count);
45+
CUDA_LONG stride = static_cast<CUDA_LONG>(last_dim);
46+
47+
const int num_threads_per_block = GridDim::maxThreadsPerBlock;
48+
const int num_elements_per_thread =
49+
(N / 2 + num_threads_per_block - 1) / num_threads_per_block;
50+
51+
switch (side) {
52+
case RotarySide::LEFT:
53+
RotaryKernel<T, RotarySide::LEFT>
54+
<<<num_elements_per_thread, num_threads_per_block, 0, stream>>>(output_data, input_data,
55+
N / 2, stride / 2);
56+
break;
57+
case RotarySide::RIGHT:
58+
RotaryKernel<T, RotarySide::RIGHT>
59+
<<<num_elements_per_thread, num_threads_per_block, 0, stream>>>(output_data, input_data,
60+
N / 2, stride / 2);
61+
break;
62+
}
63+
64+
RotaryKernel<TT><<<gridSize, blockSize, 0, stream>>>(reinterpret_cast<TT*>(output), reinterpret_cast<const TT*>(input), input_length);
65+
return cudaGetLastError();
66+
}
67+
68+
template <>
69+
cudaError_t LaunchRotaryKernel<float>(cudaStream_t stream, int input_length, int last_dim,
70+
const float* input, const int64_t* split_data, float* output, RotarySide side) {
71+
return _LaunchRotaryKernel(stream, input_length, last_dim, input, split_data, output, side);
72+
}
73+
74+
template <>
75+
cudaError_t LaunchRotaryKernel<ortc::MFloat16>(cudaStream_t stream, int input_length, int last_dim,
76+
const ortc::MFloat16* input, const int64_t* split_data,
77+
ortc::MFloat16* output, RotarySide side) {
78+
return _LaunchRotaryKernel(stream, input_length, last_dim, input, split_data, output, side);
79+
}

0 commit comments

Comments
 (0)