Skip to content

Commit 73b08c6

Browse files
committed
Add compute shader 2x2 box downsampling demo
1 parent 60926aa commit 73b08c6

File tree

4 files changed

+356
-0
lines changed

4 files changed

+356
-0
lines changed
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/*
2+
* Copyright LWJGL. All rights reserved.
3+
* License terms: https://www.lwjgl.org/license
4+
*/
5+
#version 430 core
#extension GL_KHR_shader_subgroup_shuffle : require

// Source texture; only its level 0 is sampled by this shader.
layout(location=0) uniform sampler2D baseImage;
// Three write-only output images occupying image units 0, 1 and 2.
layout(binding=0, rgba16f) uniform writeonly restrict image2D mips[3];

/*
 * This shader assumes that subgroup invocations line up with the local workgroup
 * invocations, i.e. gl_LocalInvocationID.x % gl_SubgroupSize == gl_SubgroupInvocationID.
 *
 * A workgroup consists of 256 = 16 * 16 invocations, declared one-dimensionally.
 * The 256 invocations are spread over a 16x16 tile along a z-order (Morton) curve,
 * so the dispatch still uses (width/16, height/16, 1) workgroups over the work
 * area; only the placement of the work items inside each tile is different.
 */
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

/*
 * Morton decode: gathers the even-numbered bits of an 8-bit z-order index into
 * a contiguous 4-bit coordinate in [0..15].
 * Call with the index itself for the first coordinate and with the index shifted
 * right by one for the second coordinate.
 */
int unpack(int code) {
  // keep bits 0, 2, 4, 6
  int bits = code & 0x55;
  // pair them up: bits (0,1) and (4,5)
  bits = (bits ^ (bits >> 1)) & 0x33;
  // close the remaining gap: bits 0..3
  bits = (bits ^ (bits >> 2)) & 0x0f;
  return bits;
}

void main(void) {
  ivec2 baseSize = textureSize(baseImage, 0);

  // The work area is half the base image in each dimension: every work item
  // already averages a 2x2 texel footprint with one bilinear fetch.
  ivec2 workSize = baseSize / ivec2(2);

  // Position of this invocation inside its 16x16 tile, along the z-order curve.
  int lane = int(gl_LocalInvocationID.x);
  ivec2 local = ivec2(unpack(lane), unpack(lane >> 1));

  // Global position of this work item (equals the texel written to mips[0]).
  ivec2 texel = ivec2(gl_WorkGroupID.xy) * 16 + local;

  // Work items outside the work area have nothing to do.
  if (any(greaterThanEqual(texel, workSize)))
    return;

  // --- first output: one bilinear fetch exactly at the corner shared by four base texels
  vec2 corner = (vec2(2 * texel) + 1.0) / vec2(baseSize);
  vec4 avg = textureLod(baseImage, corner, 0.0);
  imageStore(mips[0], texel, avg);

  // --- second output: average each 2x2 quad through the subgroup
  /*
   * Given the assumed 1:1 mapping between subgroup invocation ids and workgroup
   * invocation ids (modulo the subgroup size) and the z-order layout, four
   * consecutive subgroup items [0, 1, 2, 3] form one 2x2 quad. XOR-ing the id
   * with 1 yields the horizontal neighbour ([1, 0, 3, 2]), with 2 the vertical
   * neighbour and with 3 the diagonal one.
   */
  vec4 horizontal = subgroupShuffleXor(avg, 1);
  vec4 vertical = subgroupShuffleXor(avg, 2);
  vec4 diagonal = subgroupShuffleXor(avg, 3);
  avg = (avg + horizontal + vertical + diagonal) * 0.25;
  // only the first item of every quad writes the result
  if ((gl_SubgroupInvocationID & 3) == 0)
    imageStore(mips[1], texel / 2, avg);

  // --- third output: average 2x2 quads of quads through the subgroup
  /*
   * Same scheme one level up: the partners are now 4, 8 and 12 subgroup items
   * away, which subgroupShuffleXor() reaches directly.
   */
  horizontal = subgroupShuffleXor(avg, 4);
  vertical = subgroupShuffleXor(avg, 8);
  diagonal = subgroupShuffleXor(avg, 12);
  avg = (avg + horizontal + vertical + diagonal) * 0.25;
  // only the first item of every group of 16 writes the result
  if ((gl_SubgroupInvocationID & 15) == 0)
    imageStore(mips[2], texel / 4, avg);
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
/*
2+
* Copyright LWJGL. All rights reserved.
3+
* License terms: https://www.lwjgl.org/license
4+
*/
5+
#version 430 core

// Texture to display.
uniform sampler2D tex;
// Mip level of 'tex' to sample.
uniform int level;

// Interpolated texture coordinate from the vertex stage.
in vec2 texcoord;
out vec4 color;

/*
 * Outputs the texel of 'tex' at 'texcoord', read from the explicitly selected mip level.
 */
void main(void) {
  float lod = float(level);
  color = textureLod(tex, texcoord, lod);
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
/*
2+
* Copyright LWJGL. All rights reserved.
3+
* License terms: https://www.lwjgl.org/license
4+
*/
5+
#version 430 core

// Texture coordinate handed to the fragment stage.
out vec2 texcoord;

/*
 * Generates the vertex position purely from gl_VertexID, without any vertex attributes.
 * Vertex ids 0, 1, 2 yield (-1, -1), (3, -1), (-1, 3).
 */
void main(void) {
  // bit 0 of the id selects x = -1 or 3, bit 1 selects y = -1 or 3
  float x = float((gl_VertexID & 1) << 2) - 1.0;
  float y = float((gl_VertexID & 2) << 1) - 1.0;
  vec2 vertex = vec2(x, y);
  gl_Position = vec4(vertex, 0.0, 1.0);
  // map clip-space [-1, 1] to texture-space [0, 1]
  texcoord = vertex * 0.5 + vec2(0.5, 0.5);
}
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
/*
2+
* Copyright LWJGL. All rights reserved.
3+
* License terms: https://www.lwjgl.org/license
4+
*/
5+
package org.lwjgl.demo.opengl.shader;
6+
7+
import static java.lang.Math.*;
8+
import static org.lwjgl.glfw.Callbacks.glfwFreeCallbacks;
9+
import static org.lwjgl.glfw.GLFW.*;
10+
import static org.lwjgl.opengl.GL43C.*;
11+
import static org.lwjgl.system.MemoryUtil.*;
12+
13+
import java.io.IOException;
14+
import java.nio.ByteBuffer;
15+
import java.nio.IntBuffer;
16+
17+
import org.lwjgl.demo.opengl.util.DemoUtils;
18+
import org.lwjgl.opengl.GL;
19+
import org.lwjgl.opengl.GLCapabilities;
20+
import org.lwjgl.opengl.GLUtil;
21+
import org.lwjgl.system.Callback;
22+
import org.lwjgl.system.MemoryStack;
23+
import org.lwjgl.system.MemoryUtil;
24+
25+
/**
26+
* Computes 3 mip levels of a texture using only a single compute shader dispatch
27+
* and GL_KHR_shader_subgroup.
28+
*
29+
* @author Kai Burjack
30+
*/
31+
public class DownsamplingDemo {
32+
33+
private static long window;
34+
private static int width = 1024;
35+
private static int height = 768;
36+
private static boolean resetTexture;
37+
38+
private static int nullVao;
39+
private static int computeProgram;
40+
private static int quadProgram;
41+
private static int texture;
42+
private static int levelUniform;
43+
private static int level;
44+
45+
private static Callback debugProc;
46+
47+
private static void createNullVao() {
48+
nullVao = glGenVertexArrays();
49+
}
50+
51+
private static void createQuadProgram() throws IOException {
52+
int program = glCreateProgram();
53+
int vshader = DemoUtils.createShader("org/lwjgl/demo/opengl/shader/downsampling/quad.vs.glsl",
54+
GL_VERTEX_SHADER);
55+
int fshader = DemoUtils.createShader("org/lwjgl/demo/opengl/shader/downsampling/quad.fs.glsl",
56+
GL_FRAGMENT_SHADER);
57+
glAttachShader(program, vshader);
58+
glAttachShader(program, fshader);
59+
glBindFragDataLocation(program, 0, "color");
60+
glLinkProgram(program);
61+
glDetachShader(program, vshader);
62+
glDetachShader(program, fshader);
63+
glDeleteShader(vshader);
64+
glDeleteShader(fshader);
65+
int linked = glGetProgrami(program, GL_LINK_STATUS);
66+
String programLog = glGetProgramInfoLog(program);
67+
if (programLog.trim().length() > 0) {
68+
System.err.println(programLog);
69+
}
70+
if (linked == 0) {
71+
throw new AssertionError("Could not link program");
72+
}
73+
int texUniform = glGetUniformLocation(program, "tex");
74+
levelUniform = glGetUniformLocation(program, "level");
75+
glUseProgram(program);
76+
glUniform1i(texUniform, 0);
77+
glUseProgram(0);
78+
quadProgram = program;
79+
}
80+
81+
private static void createComputeProgram() throws IOException {
82+
int program = glCreateProgram();
83+
int cshader = DemoUtils.createShader("org/lwjgl/demo/opengl/shader/downsampling/downsample.cs.glsl",
84+
GL_COMPUTE_SHADER);
85+
glAttachShader(program, cshader);
86+
glLinkProgram(program);
87+
glDetachShader(program, cshader);
88+
glDeleteShader(cshader);
89+
int linked = glGetProgrami(program, GL_LINK_STATUS);
90+
String programLog = glGetProgramInfoLog(program);
91+
if (programLog.trim().length() > 0) {
92+
System.err.println(programLog);
93+
}
94+
if (linked == 0) {
95+
throw new AssertionError("Could not link program");
96+
}
97+
computeProgram = program;
98+
}
99+
100+
private static byte v(int i) {
101+
return i % 8 == 0 ? (byte) 255 : (byte) 0;
102+
}
103+
104+
private static void createTextures() {
105+
texture = glGenTextures();
106+
glBindTexture(GL_TEXTURE_2D, texture);
107+
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST_MIPMAP_LINEAR);
108+
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR);
109+
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE);
110+
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_R, GL_CLAMP_TO_EDGE);
111+
glTexStorage2D(GL_TEXTURE_2D, 4, GL_RGBA16F, width, height);
112+
ByteBuffer pixels = MemoryUtil.memAlloc(width * height * 4);
113+
// fill the first level of the texture with some pattern
114+
for (int y = 0; y < height; y++)
115+
for (int x = 0; x < width; x++)
116+
pixels.put(v(x)).put(v(y)).put((byte) 127).put((byte) 255);
117+
pixels.flip();
118+
glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, width, height, GL_RGBA, GL_UNSIGNED_BYTE, pixels);
119+
MemoryUtil.memFree(pixels);
120+
glBindTexture(GL_TEXTURE_2D, 0);
121+
}
122+
123+
private static void downsample() {
124+
glUseProgram(computeProgram);
125+
126+
// read mip level 0
127+
glBindTexture(GL_TEXTURE_2D, texture);
128+
// write mip levels 1-3
129+
glBindImageTexture(0, texture, 1, false, 0, GL_WRITE_ONLY, GL_RGBA16F);
130+
glBindImageTexture(1, texture, 2, false, 0, GL_WRITE_ONLY, GL_RGBA16F);
131+
glBindImageTexture(2, texture, 3, false, 0, GL_WRITE_ONLY, GL_RGBA16F);
132+
133+
int texelsPerWorkItem = 2;
134+
int numGroupsX = (int) ceil((double) width / texelsPerWorkItem / 16);
135+
int numGroupsY = (int) ceil((double) height / texelsPerWorkItem / 16);
136+
137+
glDispatchCompute(numGroupsX, numGroupsY, 1);
138+
glMemoryBarrier(GL_SHADER_IMAGE_ACCESS_BARRIER_BIT);
139+
140+
/* Reset bindings. */
141+
glBindTexture(GL_TEXTURE_2D, 0);
142+
glBindImageTexture(0, 0, 1, false, 0, GL_WRITE_ONLY, GL_RGBA16F);
143+
glBindImageTexture(1, 0, 2, false, 0, GL_WRITE_ONLY, GL_RGBA16F);
144+
glBindImageTexture(2, 0, 3, false, 0, GL_WRITE_ONLY, GL_RGBA16F);
145+
glUseProgram(0);
146+
}
147+
148+
private static void present() {
149+
glClear(GL_COLOR_BUFFER_BIT);
150+
glViewport(0, 0, width/(1<<level), height/(1<<level));
151+
glUseProgram(quadProgram);
152+
glUniform1i(levelUniform, level);
153+
glBindVertexArray(nullVao);
154+
glBindTexture(GL_TEXTURE_2D, texture);
155+
glDrawArrays(GL_TRIANGLES, 0, 3);
156+
glBindTexture(GL_TEXTURE_2D, 0);
157+
glBindVertexArray(0);
158+
glUseProgram(0);
159+
}
160+
161+
private static void init() throws IOException {
162+
if (!glfwInit())
163+
throw new IllegalStateException("Unable to initialize GLFW");
164+
165+
glfwDefaultWindowHints();
166+
glfwWindowHint(GLFW_OPENGL_PROFILE, GLFW_OPENGL_CORE_PROFILE);
167+
glfwWindowHint(GLFW_OPENGL_FORWARD_COMPAT, GLFW_TRUE);
168+
glfwWindowHint(GLFW_CONTEXT_VERSION_MAJOR, 4);
169+
glfwWindowHint(GLFW_CONTEXT_VERSION_MINOR, 3);
170+
glfwWindowHint(GLFW_VISIBLE, GLFW_FALSE);
171+
glfwWindowHint(GLFW_RESIZABLE, GLFW_TRUE);
172+
173+
System.out.println("Press arrow up/down to increase/decrease the viewed mip level");
174+
window = glfwCreateWindow(width, height, "Downsampling Demo", NULL, NULL);
175+
if (window == NULL) {
176+
throw new AssertionError("Failed to create the GLFW window");
177+
}
178+
glfwSetKeyCallback(window, (wnd, key, scancode, action, mods) -> {
179+
if (key == GLFW_KEY_ESCAPE && action == GLFW_RELEASE)
180+
glfwSetWindowShouldClose(window, true);
181+
else if (key == GLFW_KEY_UP && action == GLFW_RELEASE)
182+
level = min(3, level + 1);
183+
else if (key == GLFW_KEY_DOWN && action == GLFW_RELEASE)
184+
level = max(0, level - 1);
185+
});
186+
glfwSetFramebufferSizeCallback(window, (wnd, w, h) -> {
187+
if (w > 0 && h > 0 && (width != w || height != h)) {
188+
width = w;
189+
height = h;
190+
resetTexture = true;
191+
}
192+
});
193+
194+
try (MemoryStack frame = MemoryStack.stackPush()) {
195+
IntBuffer framebufferSize = frame.mallocInt(2);
196+
nglfwGetFramebufferSize(window, memAddress(framebufferSize), memAddress(framebufferSize) + 4);
197+
width = framebufferSize.get(0);
198+
height = framebufferSize.get(1);
199+
}
200+
201+
glfwMakeContextCurrent(window);
202+
203+
GLCapabilities caps = GL.createCapabilities();
204+
// Check required extensions
205+
if (!caps.GL_KHR_shader_subgroup)
206+
throw new AssertionError("GL_KHR_shader_subgroup is required but not supported");
207+
208+
debugProc = GLUtil.setupDebugMessageCallback();
209+
210+
createTextures();
211+
createNullVao();
212+
createComputeProgram();
213+
createQuadProgram();
214+
215+
glfwShowWindow(window);
216+
}
217+
218+
private static void loop() {
219+
while (!glfwWindowShouldClose(window)) {
220+
glfwPollEvents();
221+
if (resetTexture) {
222+
glDeleteTextures(texture);
223+
createTextures();
224+
resetTexture = false;
225+
}
226+
downsample();
227+
present();
228+
glfwSwapBuffers(window);
229+
}
230+
}
231+
232+
private static void destroy() {
233+
if (debugProc != null)
234+
debugProc.free();
235+
glfwDestroyWindow(window);
236+
glfwFreeCallbacks(window);
237+
glfwTerminate();
238+
}
239+
240+
public static void main(String[] args) throws IOException {
241+
init();
242+
loop();
243+
destroy();
244+
}
245+
246+
}

0 commit comments

Comments
 (0)