@@ -54,6 +54,20 @@ struct Mix {
5454 __hmul (sx_, __hsub (__float2half (1 ), r_mix_)));
5555 }
5656};
57+
58+ struct ToHalf {
59+ const float *x;
60+ half *y;
61+ __device__ void operator ()(int i) const { y[i] = __float2half (x[i]); }
62+ };
63+
64+ struct InplaceAdd {
65+ __device__ __forceinline__ half operator ()(int i) const {
66+ y[i] = __hadd (x[i], y[i]);
67+ }
68+ half *y;
69+ half *x;
70+ };
5771} // namespace
5872
5973using torch::Tensor;
@@ -64,50 +78,44 @@ void gemm_cublas(const void *a, const void *b, void *c, int batch, int ori_m,
6478 at::ScalarType torch_output_dtype);
6579
6680Tensor att_one_v5 (Tensor x, Tensor sx, Tensor s, Tensor ln_w, Tensor ln_b,
67- Tensor lx_w, Tensor lx_b, Tensor kvr_mix,
68- /* imm */ Tensor kvrx, Tensor kvrw, Tensor ow, Tensor t_first,
69- Tensor t_decay, /* imm */ Tensor kvr, /* imm */ Tensor a,
70- /* imm */ Tensor buf,
71- /* imm */ Tensor s1,
72- /* out */ Tensor x_plus_out, /* out */ Tensor s2) {
81+ Tensor lx_w, Tensor lx_b, Tensor kvr_mix, Tensor kvrw,
82+ Tensor ow, Tensor t_first, Tensor t_decay, Tensor tmp,
83+ Tensor buf, /* out */ Tensor s2_t ,
84+ /* out */ Tensor x_plus_out_t ) {
7385 const int x_numel = x.numel ();
7486 Tensor xx = at::layer_norm (x, {x_numel}, ln_w, ln_b);
87+ int H = t_decay.size (0 );
88+ int S = x_numel / H;
89+ char *buf_ptr = (char *)buf.data_ptr ();
90+ half *kvrx = (half *)buf_ptr;
91+ float *kvr = (float *)(kvrx + 3 * x_numel);
92+ float *a = kvr + 3 * x_numel;
93+ half *tmp2 = (half *)(a + H * S * S);
94+ float *s1 = (float *)(tmp2 + x_numel);
95+ float *s2 = data_ptr<float >(s2_t );
96+ half *x_plus_out = data_ptr<half>(x_plus_out_t );
97+
7598 element_wise (Mix{data_ptr<half>(xx), data_ptr<half>(sx),
76- data_ptr<half>(kvr_mix), static_cast <int >(x_numel),
77- data_ptr<half>(kvrx)},
99+ data_ptr<half>(kvr_mix), static_cast <int >(x_numel), kvrx},
78100 x_numel);
79101
80- int H = t_decay.size (0 );
81- int S = x_numel / H;
82- // gemm_cublas_tensor(at::unsqueeze(kvrx, 1), kvrw, kvr);
83- gemm_cublas (data_ptr<half>(kvrx), data_ptr<half>(kvrw), data_ptr<float >(kvr),
84- 3 , 1 , x_numel, x_numel, at::kHalf , at::kFloat );
85- float * k = data_ptr<float >(kvr);
86- float * v = k + x_numel;
87- float * r = v + x_numel;
88- // Tensor k = at::reshape(kvr[0], {H, S, 1});
89- // Tensor v = at::reshape(kvr[1], {H, 1, S});
90- // Tensor r = at::reshape(kvr[2], {H, 1, S});
91-
92- // gemm_cublas_tensor(k, v, a);
93- gemm_cublas (k, v, data_ptr<float >(a), H, S, S, 1 , at::kFloat , at::kFloat );
94- // s1 = t_first * a + s
95- // s2 = a + t_decay * s
96- element_wise (Fused1{data_ptr<float >(t_first), data_ptr<float >(t_decay),
97- data_ptr<float >(a), data_ptr<float >(s),
98- static_cast <int32_t >(a.size (1 ) * a.size (2 )),
99- data_ptr<float >(s1), data_ptr<float >(s2)},
100- a.numel ());
101-
102- // gemm_cublas_tensor(r, s1, buf);
103- gemm_cublas (r, data_ptr<float >(s1), data_ptr<float >(buf), H, 1 , S, S,
104- at::kFloat , at::kFloat );
105- buf = at::group_norm (buf, H, lx_w, lx_b);
106- buf = at::_cast_Half (buf);
107-
108- // gemm_cublas_tensor(buf, ow, x_plus_out);
109- gemm_cublas (data_ptr<half>(buf), data_ptr<half>(ow), data_ptr<half>(x_plus_out),
110- 1 , 1 , x_numel, x_numel, at::kHalf , at::kHalf );
111- x_plus_out += x;
102+ gemm_cublas (kvrx, data_ptr<half>(kvrw), kvr, 3 , 1 , x_numel, x_numel,
103+ at::kHalf , at::kFloat );
104+ float *k = kvr;
105+ float *v = k + x_numel;
106+ float *r = v + x_numel;
107+
108+ gemm_cublas (k, v, a, H, S, S, 1 , at::kFloat , at::kFloat );
109+ element_wise (Fused1{data_ptr<float >(t_first), data_ptr<float >(t_decay), a,
110+ data_ptr<float >(s), static_cast <int32_t >(S * S), s1, s2},
111+ H * S * S);
112+
113+ gemm_cublas (r, s1, data_ptr<float >(tmp), H, 1 , S, S, at::kFloat , at::kFloat );
114+ tmp = at::group_norm (tmp, H, lx_w, lx_b);
115+ element_wise (ToHalf{data_ptr<float >(tmp), tmp2}, tmp.numel ());
116+
117+ gemm_cublas (tmp2, data_ptr<half>(ow), x_plus_out, 1 , 1 , x_numel, x_numel,
118+ at::kHalf , at::kHalf );
119+ element_wise (InplaceAdd{x_plus_out, data_ptr<half>(x)}, x.numel ());
112120 return xx;
113121}
0 commit comments