Skip to content

Commit e7fb001

Browse files
committed
Merge branch 'master' into crokeso
2 parents e8b62e5 + 343b6e9 commit e7fb001

File tree

8 files changed

+347
-33
lines changed

8 files changed

+347
-33
lines changed

ggml/include/ggml.h

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2060,6 +2060,12 @@ extern "C" {
20602060
// Sampling modes accepted as the `mode` argument of ggml_upscale_ext / ggml_interpolate.
enum ggml_scale_mode {
20612061
GGML_SCALE_MODE_NEAREST = 0,
20622062
GGML_SCALE_MODE_BILINEAR = 1,
2063+
2064+
GGML_SCALE_MODE_COUNT // one past the last mode (evaluates to 2 here); not a mode itself
2065+
};
2066+
2067+
// Option bits that can be OR'ed together with a ggml_scale_mode value in the `mode` argument of ggml_interpolate.
enum ggml_scale_flag {
2068+
GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8) // bit 8: sits above the low byte (0xFF) that the CPU upscale op masks to read the mode
20632069
};
20642070

20652071
// interpolate
@@ -2072,14 +2078,26 @@ extern "C" {
20722078

20732079
// interpolate
20742080
// interpolate scale to specified dimensions
2075-
GGML_API struct ggml_tensor * ggml_upscale_ext(
2081+
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_upscale_ext(
20762082
struct ggml_context * ctx,
20772083
struct ggml_tensor * a,
20782084
int ne0,
20792085
int ne1,
20802086
int ne2,
20812087
int ne3,
2082-
enum ggml_scale_mode mode);
2088+
enum ggml_scale_mode mode),
2089+
"use ggml_interpolate instead");
2090+
2091+
// Up- or downsamples the input to the specified size.
2092+
// 2D scale modes (e.g. bilinear) are applied to the first two dimensions.
2093+
// Parameters:
//   ctx       - ggml context handed to the op
//   a         - input tensor to resample
//   ne0..ne3  - requested output size, one value per dimension
//   mode      - a ggml_scale_mode value, optionally OR'ed with ggml_scale_flag bits
// ggml_upscale_ext is marked deprecated in favour of this function.
GGML_API struct ggml_tensor * ggml_interpolate(
2094+
struct ggml_context * ctx,
2095+
struct ggml_tensor * a,
2096+
int64_t ne0,
2097+
int64_t ne1,
2098+
int64_t ne2,
2099+
int64_t ne3,
2100+
uint32_t mode); // ggml_scale_mode [ | ggml_scale_flag...]
20832101

20842102
// pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
20852103
GGML_API struct ggml_tensor * ggml_pad(

ggml/src/ggml-cpu/ops.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8471,12 +8471,13 @@ static void ggml_compute_forward_upscale_f32(
84718471

84728472
GGML_TENSOR_UNARY_OP_LOCALS
84738473

8474-
const float sf0 = (float)ne0/src0->ne[0];
8475-
const float sf1 = (float)ne1/src0->ne[1];
8476-
const float sf2 = (float)ne2/src0->ne[2];
8477-
const float sf3 = (float)ne3/src0->ne[3];
8474+
float sf0 = (float)ne0/src0->ne[0];
8475+
float sf1 = (float)ne1/src0->ne[1];
8476+
float sf2 = (float)ne2/src0->ne[2];
8477+
float sf3 = (float)ne3/src0->ne[3];
84788478

8479-
const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0);
8479+
const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
8480+
const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
84808481

84818482
if (mode == GGML_SCALE_MODE_NEAREST) {
84828483
for (int64_t i3 = 0; i3 < ne3; i3++) {
@@ -8497,8 +8498,12 @@ static void ggml_compute_forward_upscale_f32(
84978498
}
84988499
}
84998500
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
8500-
// setting a pixel offset of 0 would replicate the behavior of pytorch interpolate with align_corners=True
8501-
const float pixel_offset = 0.5f;
8501+
float pixel_offset = 0.5f;
8502+
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
8503+
pixel_offset = 0.0f;
8504+
sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
8505+
sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
8506+
}
85028507

85038508
for (int64_t i3 = 0; i3 < ne3; i3++) {
85048509
const int64_t i03 = i3 / sf3;
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2+
3+
#define GELU_COEF_A 0.044715f
4+
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
5+
6+
//------------------------------------------------------------------------------
7+
// geglu
8+
//------------------------------------------------------------------------------
9+
// GEGLU (f32): dst[r][i] = gelu(a[r][i]) * b[r][i], where gelu is the tanh-based
// approximation built from GELU_COEF_A and SQRT_2_OVER_PI.
//
// The work-group index get_group_id(0) selects the row; the work-items of the
// group walk the ne0 elements of that row in steps of get_local_size(0).
//
//   src0, src1, dst            - buffers, addressed in bytes
//   offset0, offset1, offsetd  - byte offsets added to the respective buffer
//   nb01, nb11, nb1            - byte distance between consecutive rows of src0 / src1 / dst
//   ne0                        - number of elements produced per row
//   ne00_off, ne10_off         - additional offsets into the src0 / src1 row, counted in floats
kernel void kernel_geglu(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        ulong         nb01,
        ulong         nb11,
        int           ne0,
        ulong         nb1,
        int           ne00_off,
        int           ne10_off
) {
    const size_t row = get_group_id(0);

    // Row bases are computed in bytes first; the element offsets are applied after the cast.
    global float * a_row   = (global float *) (src0 + offset0 + row*nb01) + ne00_off;
    global float * b_row   = (global float *) (src1 + offset1 + row*nb11) + ne10_off;
    global float * out_row = (global float *) (dst  + offsetd + row*nb1);

    int col = get_local_id(0);
    while (col < ne0) {
        const float a = a_row[col];
        const float b = b_row[col];

        const float act = 0.5f*a*(1.0f + tanh(SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a)));

        out_row[col] = act*b;

        col += get_local_size(0);
    }
}
40+
41+
// GEGLU (f16): same computation as the f32 kernel, with half-precision rows.
// The GELU expression mixes half operands with float literals/macros, so it is
// evaluated in float and narrowed to half when stored in `act`.
//
// The work-group index get_group_id(0) selects the row; the work-items of the
// group walk the ne0 elements of that row in steps of get_local_size(0).
//
//   src0, src1, dst            - buffers, addressed in bytes
//   offset0, offset1, offsetd  - byte offsets added to the respective buffer
//   nb01, nb11, nb1            - byte distance between consecutive rows of src0 / src1 / dst
//   ne0                        - number of elements produced per row
//   ne00_off, ne10_off         - additional offsets into the src0 / src1 row, counted in halfs
kernel void kernel_geglu_f16(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        ulong         nb01,
        ulong         nb11,
        int           ne0,
        ulong         nb1,
        int           ne00_off,
        int           ne10_off
) {
    const size_t row = get_group_id(0);

    // Row bases are computed in bytes first; the element offsets are applied after the cast.
    global half * a_row   = (global half *) (src0 + offset0 + row*nb01) + ne00_off;
    global half * b_row   = (global half *) (src1 + offset1 + row*nb11) + ne10_off;
    global half * out_row = (global half *) (dst  + offsetd + row*nb1);

    int col = get_local_id(0);
    while (col < ne0) {
        const half a = a_row[col];
        const half b = b_row[col];

        const half act = 0.5f*a*(1.0f + tanh(SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a)));

        out_row[col] = act*b;

        col += get_local_size(0);
    }
}
72+
73+
//------------------------------------------------------------------------------
74+
// reglu
75+
//------------------------------------------------------------------------------
76+
// REGLU (f32): dst[r][i] = relu(src0[r][i]) * src1[r][i].
//
// The work-group index get_group_id(0) selects the row; the work-items of the
// group walk the ne0 elements of that row in steps of get_local_size(0).
//
//   src0, src1, dst            - buffers, addressed in bytes
//   offset0, offset1, offsetd  - byte offsets added to the respective buffer
//   nb01, nb11, nb1            - byte distance between consecutive rows of src0 / src1 / dst
//   ne0                        - number of elements produced per row
//   ne00_off, ne10_off         - additional offsets into the src0 / src1 row, counted in floats
kernel void kernel_reglu(
    global char * src0,
    ulong offset0,
    global char * src1,
    ulong offset1,
    global char * dst,
    ulong offsetd,
    ulong nb01,
    ulong nb11,
    int ne0,
    ulong nb1,
    int ne00_off,
    int ne10_off
) {
    src0 = (global char*)((global char*)src0 + offset0);
    src1 = (global char*)((global char*)src1 + offset1);
    dst  = (global char*)((global char*)dst  + offsetd);

    global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
    global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
    global float * dst_row  = (global float *) ((global char *) dst  + get_group_id(0)*nb1);

    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
        const float x0 = src0_row[i0];
        const float x1 = src1_row[i0];

        // Apply the ReLU before multiplying. Masking the finished product
        // (x0*x1*(x0 > 0)) turns an overflowed x0*x1 into inf*0 = NaN even
        // though the ReLU output is 0. Written with <= so that a NaN in x0
        // still propagates to the result.
        const float relu = (x0 <= 0.0f) ? 0.0f : x0;

        dst_row[i0] = relu*x1;
    }
}
105+
106+
// REGLU (f16): dst[r][i] = relu(src0[r][i]) * src1[r][i], with half-precision rows.
//
// The work-group index get_group_id(0) selects the row; the work-items of the
// group walk the ne0 elements of that row in steps of get_local_size(0).
//
//   src0, src1, dst            - buffers, addressed in bytes
//   offset0, offset1, offsetd  - byte offsets added to the respective buffer
//   nb01, nb11, nb1            - byte distance between consecutive rows of src0 / src1 / dst
//   ne0                        - number of elements produced per row
//   ne00_off, ne10_off         - additional offsets into the src0 / src1 row, counted in halfs
kernel void kernel_reglu_f16(
    global char * src0,
    ulong offset0,
    global char * src1,
    ulong offset1,
    global char * dst,
    ulong offsetd,
    ulong nb01,
    ulong nb11,
    int ne0,
    ulong nb1,
    int ne00_off,
    int ne10_off
) {
    src0 = (global char*)((global char*)src0 + offset0);
    src1 = (global char*)((global char*)src1 + offset1);
    dst  = (global char*)((global char*)dst  + offsetd);

    global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
    global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
    global half * dst_row  = (global half *) ((global char *) dst  + get_group_id(0)*nb1);

    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
        const half x0 = src0_row[i0];
        const half x1 = src1_row[i0];

        // Apply the ReLU before multiplying. The half product x0*x1 overflows
        // to +/-inf once its magnitude exceeds 65504 (e.g. -300 * 300), and
        // masking that product afterwards gives inf*0 = NaN where the result
        // should be 0. Written with <= so that a NaN in x0 still propagates.
        const half relu = (x0 <= 0.0f) ? (half) 0.0f : x0;

        dst_row[i0] = relu*x1;
    }
}
135+
136+
//------------------------------------------------------------------------------
137+
// swiglu
138+
//------------------------------------------------------------------------------
139+
// SWIGLU (f32): dst[r][i] = silu(a[r][i]) * b[r][i], with silu(x) = x / (1 + exp(-x)).
//
// The work-group index get_group_id(0) selects the row; the work-items of the
// group walk the ne0 elements of that row in steps of get_local_size(0).
//
//   src0, src1, dst            - buffers, addressed in bytes
//   offset0, offset1, offsetd  - byte offsets added to the respective buffer
//   nb01, nb11, nb1            - byte distance between consecutive rows of src0 / src1 / dst
//   ne0                        - number of elements produced per row
//   ne00_off, ne10_off         - additional offsets into the src0 / src1 row, counted in floats
kernel void kernel_swiglu(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        ulong         nb01,
        ulong         nb11,
        int           ne0,
        ulong         nb1,
        int           ne00_off,
        int           ne10_off
) {
    const size_t row = get_group_id(0);

    // Row bases are computed in bytes first; the element offsets are applied after the cast.
    global float * a_row   = (global float *) (src0 + offset0 + row*nb01) + ne00_off;
    global float * b_row   = (global float *) (src1 + offset1 + row*nb11) + ne10_off;
    global float * out_row = (global float *) (dst  + offsetd + row*nb1);

    int col = get_local_id(0);
    while (col < ne0) {
        const float a = a_row[col];
        const float b = b_row[col];

        const float act = a / (1.0f + exp(-a));

        out_row[col] = act*b;

        col += get_local_size(0);
    }
}
170+
171+
// SWIGLU (f16): dst[r][i] = silu(a[r][i]) * b[r][i], with silu(x) = x / (1 + exp(-x)),
// on half-precision rows. exp() receives a half argument; the sum with the
// float literal 1.0f and the division are then carried out in float and
// narrowed to half when stored in `act`.
//
// The work-group index get_group_id(0) selects the row; the work-items of the
// group walk the ne0 elements of that row in steps of get_local_size(0).
//
//   src0, src1, dst            - buffers, addressed in bytes
//   offset0, offset1, offsetd  - byte offsets added to the respective buffer
//   nb01, nb11, nb1            - byte distance between consecutive rows of src0 / src1 / dst
//   ne0                        - number of elements produced per row
//   ne00_off, ne10_off         - additional offsets into the src0 / src1 row, counted in halfs
kernel void kernel_swiglu_f16(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        ulong         nb01,
        ulong         nb11,
        int           ne0,
        ulong         nb1,
        int           ne00_off,
        int           ne10_off
) {
    const size_t row = get_group_id(0);

    // Row bases are computed in bytes first; the element offsets are applied after the cast.
    global half * a_row   = (global half *) (src0 + offset0 + row*nb01) + ne00_off;
    global half * b_row   = (global half *) (src1 + offset1 + row*nb11) + ne10_off;
    global half * out_row = (global half *) (dst  + offsetd + row*nb1);

    int col = get_local_id(0);
    while (col < ne0) {
        const half a = a_row[col];
        const half b = b_row[col];

        const half act = a / (1.0f + exp(-a));

        out_row[col] = act*b;

        col += get_local_size(0);
    }
}

ggml/src/ggml-quants.c

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -741,14 +741,14 @@ static float make_qkx2_quants(int n, int nmax, const float * GGML_RESTRICT x, co
741741
}
742742
float iscale = nmax/(max - min);
743743
float scale = 1/iscale;
744-
float best_mad = 0;
744+
float best_error = 0;
745745
for (int i = 0; i < n; ++i) {
746746
int l = nearest_int(iscale*(x[i] - min));
747747
L[i] = MAX(0, MIN(nmax, l));
748748
float diff = scale * L[i] + min - x[i];
749749
diff = use_mad ? fabsf(diff) : diff * diff;
750750
float w = weights[i];
751-
best_mad += w * diff;
751+
best_error += w * diff;
752752
}
753753
if (nstep < 1) {
754754
*the_min = -min;
@@ -774,18 +774,18 @@ static float make_qkx2_quants(int n, int nmax, const float * GGML_RESTRICT x, co
774774
this_min = 0;
775775
this_scale = sum_xl / sum_l2;
776776
}
777-
float mad = 0;
777+
float cur_error = 0;
778778
for (int i = 0; i < n; ++i) {
779779
float diff = this_scale * Laux[i] + this_min - x[i];
780780
diff = use_mad ? fabsf(diff) : diff * diff;
781781
float w = weights[i];
782-
mad += w * diff;
782+
cur_error += w * diff;
783783
}
784-
if (mad < best_mad) {
784+
if (cur_error < best_error) {
785785
for (int i = 0; i < n; ++i) {
786786
L[i] = Laux[i];
787787
}
788-
best_mad = mad;
788+
best_error = cur_error;
789789
scale = this_scale;
790790
min = this_min;
791791
}

0 commit comments

Comments
 (0)