
Commit 3d34d33

Revert FUSED_RMS_NORM

Revert "Up Fused RMS Norm"

1 parent ebc7744 commit 3d34d33

10 files changed, with 3 additions and 324 deletions.

ggml/include/ggml.h

Lines changed: 0 additions & 13 deletions

@@ -474,7 +474,6 @@ extern "C" {
         GGML_OP_RMS_NORM,
         GGML_OP_RMS_NORM_BACK,
         GGML_OP_GROUP_NORM,
-        GGML_OP_FUSED_RMS_NORM,
 
         GGML_OP_MUL_MAT,
         GGML_OP_MUL_MAT_ID,
@@ -1109,18 +1108,6 @@ extern "C" {
             struct ggml_tensor  * a,
             float                 eps);
 
-    GGML_API struct ggml_tensor * ggml_fused_rms_norm(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            float                 eps);
-
-    GGML_API struct ggml_tensor * ggml_fused_rms_norm_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            float                 eps);
-
     // group normalize along ne0*ne1*n_groups
     // used in stable-diffusion
     GGML_API struct ggml_tensor * ggml_group_norm(
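
With the fused op removed from the public API, a graph can express the same computation with the two ops that remain: ggml_rms_norm followed by ggml_mul. A minimal sketch under that assumption (the helper name is illustrative, not part of this tree):

// Unfused equivalent of the removed ggml_fused_rms_norm(ctx, cur, w, eps):
// normalize the rows of cur, then scale element-wise by the weight row w
// (ggml_mul broadcasts the single row of w across all rows of cur).
static struct ggml_tensor * rms_norm_then_mul(
        struct ggml_context * ctx,
        struct ggml_tensor  * cur,  // activations
        struct ggml_tensor  * w,    // norm weights, one row, w->ne[0] == cur->ne[0]
        float                 eps) {
    cur = ggml_rms_norm(ctx, cur, eps);
    return ggml_mul(ctx, cur, w);
}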

ggml/src/ggml-alloc.c

Lines changed: 0 additions & 1 deletion

@@ -53,7 +53,6 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
         case GGML_OP_UNARY:
         case GGML_OP_ROPE:
         case GGML_OP_RMS_NORM:
-        case GGML_OP_FUSED_RMS_NORM:
         case GGML_OP_SOFT_MAX:
             return true;
 

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 0 additions & 73 deletions

@@ -7157,78 +7157,6 @@ static void ggml_compute_forward_rms_norm(
     }
 }
 
-static void ggml_compute_forward_fused_rms_norm_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    if (!src1) {
-        ggml_compute_forward_rms_norm_f32(params, dst);
-        return;
-    }
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-    GGML_ASSERT(src1->nb[0] == sizeof(float));
-    GGML_ASSERT(src1->ne[0] == src0->ne[0]);
-    GGML_ASSERT(ggml_nrows(src1) == 1);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    float eps;
-    memcpy(&eps, dst->op_params, sizeof(float));
-
-    GGML_ASSERT(eps > 0.0f);
-
-    // TODO: optimize
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
-                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
-                ggml_float sum = 0.0;
-                for (int64_t i00 = 0; i00 < ne00; i00++) {
-                    sum += (ggml_float)(x[i00] * x[i00]);
-                }
-
-                const float mean = sum/ne00;
-
-                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
-
-                const float scale = 1.0f/sqrtf(mean + eps);
-
-                ggml_vec_mul_f32(ne00, y, x, (const float *)src1->data);
-                ggml_vec_scale_f32(ne00, y, scale);
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_fused_rms_norm(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_fused_rms_norm_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
 static void ggml_compute_forward_rms_norm_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -13019,7 +12947,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_DIV:
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
-        case GGML_OP_FUSED_RMS_NORM:
         case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_CONCAT:
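
For reference, the removed CPU path computes, per row, y[i] = x[i]*w[i] / sqrt(mean(x^2) + eps); multiplying by the weight row before the scale (ggml_vec_mul_f32, then ggml_vec_scale_f32) or after it is mathematically equivalent. A self-contained scalar sketch of one row, with illustrative values:

#include <math.h>
#include <stdio.h>

// Scalar reference for one row of the removed fused op:
//   y[i] = x[i] * w[i] / sqrt(mean(x^2) + eps)
static void fused_rms_norm_row_ref(int n, float * y,
                                   const float * x, const float * w, float eps) {
    double sum = 0.0; // the removed code accumulates in ggml_float (double)
    for (int i = 0; i < n; i++) {
        sum += (double) x[i] * x[i];
    }
    const float scale = 1.0f/sqrtf((float)(sum/n) + eps);
    for (int i = 0; i < n; i++) {
        y[i] = x[i] * w[i] * scale;
    }
}

int main(void) {
    const float x[4] = {1.0f, 2.0f, 2.0f, 4.0f};
    const float w[4] = {1.0f, 1.0f, 0.5f, 0.5f};
    float y[4];
    fused_rms_norm_row_ref(4, y, x, w, 1e-6f);
    // mean(x^2) = (1 + 4 + 4 + 16)/4 = 6.25, scale = 1/sqrt(6.25) = 0.4
    printf("%.3f %.3f %.3f %.3f\n", y[0], y[1], y[2], y[3]); // 0.400 0.800 0.400 0.800
    return 0;
}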

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 0 additions & 6 deletions

@@ -2253,9 +2253,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_RMS_NORM:
             ggml_cuda_op_rms_norm(ctx, dst);
             break;
-        case GGML_OP_FUSED_RMS_NORM:
-            ggml_cuda_op_fused_rms_norm(ctx, dst);
-            break;
         case GGML_OP_MUL_MAT:
             if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
                 GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
@@ -3140,9 +3137,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_RMS_NORM:
             return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
             break;
-        case GGML_OP_FUSED_RMS_NORM:
-            return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
-            break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:

ggml/src/ggml-cuda/norm.cu

Lines changed: 0 additions & 75 deletions

@@ -131,40 +131,6 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
     }
 }
 
-template <int block_size>
-static __global__ void fused_rms_norm_f32(const float * x, const float * y, float * dst, const int ncols, const float eps) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    const int tid = threadIdx.x;
-
-    float tmp = 0.0f; // partial sum for thread in warp
-
-    for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x[row*ncols + col];
-        tmp += xi * xi;
-    }
-
-    // sum up partial sums
-    tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
-        __shared__ float s_sum[32];
-        int warp_id = threadIdx.x / WARP_SIZE;
-        int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        __syncthreads();
-        tmp = s_sum[lane_id];
-        tmp = warp_reduce_sum(tmp);
-    }
-
-    const float mean = tmp / ncols;
-    const float scale = rsqrtf(mean + eps);
-
-    for (int col = tid; col < ncols; col += block_size) {
-        dst[row*ncols + col] = scale * y[col] * x[row*ncols + col];
-    }
-}
-
 static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
@@ -197,18 +163,6 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
     }
 }
 
-static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
-        const int ncols, const int nrows, const float eps, cudaStream_t stream) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
-    if (ncols < 1024) {
-        const dim3 block_dims(WARP_SIZE, 1, 1);
-        fused_rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
-    } else {
-        const dim3 block_dims(1024, 1, 1);
-        fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
-    }
-}
-
 void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
@@ -268,32 +222,3 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
 }
-
-void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    if (!dst->src[1]) {
-        ggml_cuda_op_rms_norm(ctx, dst);
-        return;
-    }
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-    const float * src0_d = (const float *)src0->data;
-    const float * src1_d = (const float *)src1->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(src0->ne[0] == src1->ne[0]);
-    GGML_ASSERT(ggml_nrows(src1) == 1);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
-
-    float eps;
-    memcpy(&eps, dst->op_params, sizeof(float));
-
-    fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
-}
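
The removed launcher maps one block per row and picks the reduction width from the row length: a single 32-thread warp when ncols < 1024, otherwise a 1024-thread block whose 32 warps combine their partial sums through the 32-slot s_sum array. A sequential C model of that two-level reduction, assuming WARP_SIZE == 32:

// Sequential model of the deleted kernel's block-wide sum (assumes 32-lane warps).
// Level 1: each warp reduces its lanes' partial sums (warp_reduce_sum on device).
// Level 2: one warp reduces the per-warp results; 1024/32 = 32 warps at most,
// which is why s_sum needs exactly 32 slots.
static float block_reduce_sum_model(const float * partial, int n_threads) {
    float warp_sums[32] = {0.0f};
    const int n_warps = n_threads/32;
    for (int w = 0; w < n_warps; w++) {        // level 1: per-warp reduction
        for (int lane = 0; lane < 32; lane++) {
            warp_sums[w] += partial[w*32 + lane];
        }
    }
    float total = 0.0f;                        // level 2: reduce warp leaders
    for (int w = 0; w < n_warps; w++) {
        total += warp_sums[w];
    }
    return total;
}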

ggml/src/ggml-cuda/norm.cuh

Lines changed: 0 additions & 2 deletions

@@ -5,5 +5,3 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
-
-void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 0 additions & 34 deletions

@@ -156,7 +156,6 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
-    GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
     GGML_METAL_KERNEL_TYPE_NORM,
     GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
@@ -1008,7 +1007,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_SUM_ROWS:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_RMS_NORM:
-        case GGML_OP_FUSED_RMS_NORM:
         case GGML_OP_GROUP_NORM:
             return has_simdgroup_reduction;
         case GGML_OP_NORM:
@@ -2638,38 +2636,6 @@ static void ggml_metal_encode_node(
 
                 [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
-        case GGML_OP_FUSED_RMS_NORM:
-            {
-                GGML_ASSERT(ne00 % 4 == 0);
-                GGML_ASSERT(ggml_is_contiguous_1(src0));
-                GGML_ASSERT(src1->ne[0] == src0->ne[0]);
-                GGML_ASSERT(src1->type == GGML_TYPE_F32);
-                GGML_ASSERT(ggml_nrows(src1) == 1);
-
-                float eps;
-                memcpy(&eps, dst->op_params, sizeof(float));
-
-                int nth = 32; // SIMD width
-
-                while (nth < ne00/4 && nth < 1024) {
-                    nth *= 2;
-                }
-
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM].pipeline;
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
-                [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
-                [encoder setBytes:&eps  length:sizeof(   float) atIndex:5];
-                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
-
-                const int64_t nrows = ggml_nrows(src0);
-
-                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-            } break;
         case GGML_OP_GROUP_NORM:
             {
                 GGML_ASSERT(ne00 % 4 == 0);

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 0 additions & 51 deletions

@@ -1445,57 +1445,6 @@ kernel void kernel_rms_norm(
     }
 }
 
-kernel void kernel_fused_rms_norm(
-        device const  void * src0,
-        device const  void * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant     float & eps,
-        threadgroup float  * buf [[threadgroup(0)]],
-        uint tgpig[[threadgroup_position_in_grid]],
-        uint tpitg[[thread_position_in_threadgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint   ntg[[threads_per_threadgroup]]) {
-    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
-
-    float4 sumf = 0;
-    float all_sum = 0;
-
-    // parallel sum
-    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        sumf += x[i00] * x[i00];
-    }
-    all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
-    all_sum = simd_sum(all_sum);
-    if (ntg > N_SIMDWIDTH) {
-        if (sgitg == 0) {
-            buf[tiisg] = 0.0f;
-        }
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        if (tiisg == 0) {
-            buf[sgitg] = all_sum;
-        }
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        all_sum = buf[tiisg];
-        all_sum = simd_sum(all_sum);
-    }
-
-    const float mean = all_sum/ne00;
-    const float scale = 1.0f/sqrt(mean + eps);
-
-    device float4 * y = (device float4 *) (dst + tgpig*ne00);
-    device float4 * z = (device float4 *)src1;
-    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        y[i00] = x[i00] * z[i00] * scale;
-    }
-}
-
 kernel void kernel_group_norm(
         device const float * src0,
         device       float * dst,
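
The deleted Metal kernel reads each row as float4 vectors, which is what the ne00 % 4 == 0 assertion in the ggml-metal.m encoder guards; it keeps four running sums of squares and folds them into one scalar before the simdgroup reduction. A plain-C model of that accumulation, illustrative only:

// Model of the float4 sum-of-squares in the deleted kernel_fused_rms_norm,
// matching: sumf += x[i00] * x[i00]; all_sum = sumf[0]+sumf[1]+sumf[2]+sumf[3];
static float sum_squares_vec4_model(const float * x, int ne00) {
    float s[4] = {0.0f, 0.0f, 0.0f, 0.0f};
    for (int i = 0; i < ne00; i += 4) {     // ne00 % 4 == 0 is asserted by the encoder
        for (int k = 0; k < 4; k++) {
            s[k] += x[i + k] * x[i + k];    // four parallel accumulators, one per lane
        }
    }
    return s[0] + s[1] + s[2] + s[3];       // fold before the final reduction
}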
