Skip to content

Commit 071ee0a

Browse files
dakersankhesh
authored and committed
fix(WebGPU): optimize vtkTexture.generateMipmaps
This commit generates mipmaps for a given GPU texture using a compute shader. fixes #3260
1 parent eeffdf4 commit 071ee0a

File tree

4 files changed

+244
-171
lines changed

4 files changed

+244
-171
lines changed

Sources/Rendering/Core/Texture/index.d.ts

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,19 +98,21 @@ export function extend(
9898
export function newInstance(initialValues?: ITextureInitialValues): vtkTexture;
9999

100100
/**
101-
* Method used to create mipmaps from given texture data. Works best with textures that have a
102-
* width and a height that are powers of two.
101+
* Generates mipmaps for a given GPU texture using a compute shader.
103102
*
104-
* @param nativeArray the array of data to create mipmaps from.
105-
* @param width the width of the data
106-
* @param height the height of the data
107-
* @param level the level to which additional mipmaps are generated.
103+
* This function iteratively generates each mip level for the provided texture,
104+
* using a bilinear downsampling compute shader implemented in WGSL. It creates
105+
* the necessary pipeline, bind groups, and dispatches compute passes for each
106+
* mip level.
107+
*
108+
* @param {GPUDevice} device - The WebGPU device used to create resources and submit commands.
109+
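* @param {GPUTexture} texture - The GPU texture for which mipmaps will be generated. Must be an `rgba8unorm` texture created with mip levels and with both TEXTURE_BINDING and STORAGE_BINDING usage.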
110+
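* @param {number} mipLevelCount - The total number of mip levels in the texture, including the base level. Level 0 is only read; levels 1 through mipLevelCount - 1 are each generated from the previous level.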
108111
*/
109112
export function generateMipmaps(
110-
nativeArray: any,
111-
width: number,
112-
height: number,
113-
level: number
113+
device: any,
114+
texture: any,
115+
mipLevelCount: number
114116
): Array<Uint8ClampedArray>;
115117

116118
/**

Sources/Rendering/Core/Texture/index.js

Lines changed: 136 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
/* eslint-disable no-bitwise */
12
import macro from 'vtk.js/Sources/macros';
23

34
// ----------------------------------------------------------------------------
@@ -139,101 +140,146 @@ function vtkTexture(publicAPI, model) {
139140
};
140141
}
141142

142-
// Use nativeArray instead of self
143-
const generateMipmaps = (nativeArray, width, height, level) => {
144-
// TODO: FIX UNEVEN TEXTURE MIP GENERATION:
145-
// When textures don't have standard ratios, higher mip levels
146-
// result in their color chanels getting messed up and shifting
147-
// 3x3 gaussian kernel
148-
const g3m = [1, 2, 1]; // eslint-disable-line
149-
const g3w = 4; // eslint-disable-line
150-
// 5x5 gaussian kernel
151-
const g5m = [1, 2, 4, 2, 1]; // eslint-disable-line
152-
const g5w = 10; // eslint-disable-line
153-
// 7x7 gaussian kernel
154-
const g7m = [1, 2, 6, 8, 6, 2, 1]; // eslint-disable-line
155-
const g7w = 26; // eslint-disable-line
156-
157-
const kernel = g3m;
158-
const kernelWeight = g3w;
159-
160-
const hs = nativeArray.length / (width * height); // TODO: support for textures with depth more than 1
161-
let currentWidth = width;
162-
let currentHeight = height;
163-
let imageData = nativeArray;
164-
const maps = [imageData];
165-
166-
for (let i = 0; i < level; i++) {
167-
const oldData = [...imageData];
168-
currentWidth /= 2;
169-
currentHeight /= 2;
170-
imageData = new Uint8ClampedArray(currentWidth * currentHeight * hs);
171-
const vs = hs * currentWidth;
172-
173-
// Scale down
174-
let shift = 0;
175-
for (let p = 0; p < imageData.length; p += hs) {
176-
if (p % vs === 0) {
177-
shift += 2 * hs * currentWidth;
143+
/**
144+
* Generates mipmaps for a given GPU texture using a compute shader.
145+
*
146+
* This function iteratively generates each mip level for the provided texture,
147+
* using a bilinear downsampling compute shader implemented in WGSL. It creates
148+
* the necessary pipeline, bind groups, and dispatches compute passes for each
149+
* mip level.
150+
*
151+
* @param {GPUDevice} device - The WebGPU device used to create resources and submit commands.
152+
* @param {GPUTexture} texture - The GPU texture for which mipmaps will be generated. Must be an `rgba8unorm` texture created with mip levels and with both TEXTURE_BINDING and STORAGE_BINDING usage.
153+
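* @param {number} mipLevelCount - The total number of mip levels in the texture, including the base level. Level 0 is only read; levels 1 through mipLevelCount - 1 are each generated from the previous level.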
154+
*/
155+
const generateMipmaps = (device, texture, mipLevelCount) => {
156+
const computeShaderCode = `
157+
@group(0) @binding(0) var inputTexture: texture_2d<f32>;
158+
@group(0) @binding(1) var outputTexture: texture_storage_2d<rgba8unorm, write>;
159+
160+
@compute @workgroup_size(8, 8)
161+
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
162+
let texelCoord = vec2<i32>(global_id.xy);
163+
let outputSize = textureDimensions(outputTexture);
164+
165+
if (texelCoord.x >= i32(outputSize.x) || texelCoord.y >= i32(outputSize.y)) {
166+
return;
178167
}
179168
180-
for (let c = 0; c < hs; c++) {
181-
let sample = oldData[shift + c];
182-
sample += oldData[shift + hs + c];
183-
sample += oldData[shift - 2 * vs + c];
184-
sample += oldData[shift - 2 * vs + hs + c];
185-
sample /= 4;
186-
imageData[p + c] = sample;
187-
}
188-
189-
shift += 2 * hs;
190-
}
169+
let inputSize = textureDimensions(inputTexture);
170+
let scale = vec2<f32>(inputSize) / vec2<f32>(outputSize);
171+
172+
// Compute the floating-point source coordinate
173+
let srcCoord = (vec2<f32>(texelCoord) + 0.5) * scale - 0.5;
174+
175+
// Get integer coordinates for the four surrounding texels
176+
let x0 = i32(floor(srcCoord.x));
177+
let x1 = min(x0 + 1, i32(inputSize.x) - 1);
178+
let y0 = i32(floor(srcCoord.y));
179+
let y1 = min(y0 + 1, i32(inputSize.y) - 1);
180+
181+
// Compute the weights
182+
let wx = srcCoord.x - f32(x0);
183+
let wy = srcCoord.y - f32(y0);
184+
185+
// Fetch the four texels
186+
let c00 = textureLoad(inputTexture, vec2<i32>(x0, y0), 0);
187+
let c10 = textureLoad(inputTexture, vec2<i32>(x1, y0), 0);
188+
let c01 = textureLoad(inputTexture, vec2<i32>(x0, y1), 0);
189+
let c11 = textureLoad(inputTexture, vec2<i32>(x1, y1), 0);
190+
191+
// Bilinear interpolation
192+
let color = mix(
193+
mix(c00, c10, wx),
194+
mix(c01, c11, wx),
195+
wy
196+
);
191197
192-
// Horizontal Pass
193-
let dataCopy = [...imageData];
194-
for (let p = 0; p < imageData.length; p += hs) {
195-
for (let c = 0; c < hs; c++) {
196-
let x = -(kernel.length - 1) / 2;
197-
let kw = kernelWeight;
198-
let value = 0.0;
199-
for (let k = 0; k < kernel.length; k++) {
200-
let index = p + c + x * hs;
201-
const lineShift = (index % vs) - ((p + c) % vs);
202-
if (lineShift > hs) index += vs;
203-
if (lineShift < -hs) index -= vs;
204-
if (dataCopy[index]) {
205-
value += dataCopy[index] * kernel[k];
206-
} else {
207-
kw -= kernel[k];
208-
}
209-
x += 1;
210-
}
211-
imageData[p + c] = value / kw;
212-
}
213-
}
214-
// Vertical Pass
215-
dataCopy = [...imageData];
216-
for (let p = 0; p < imageData.length; p += hs) {
217-
for (let c = 0; c < hs; c++) {
218-
let x = -(kernel.length - 1) / 2;
219-
let kw = kernelWeight;
220-
let value = 0.0;
221-
for (let k = 0; k < kernel.length; k++) {
222-
const index = p + c + x * vs;
223-
if (dataCopy[index]) {
224-
value += dataCopy[index] * kernel[k];
225-
} else {
226-
kw -= kernel[k];
227-
}
228-
x += 1;
229-
}
230-
imageData[p + c] = value / kw;
231-
}
198+
textureStore(outputTexture, texelCoord, color);
232199
}
233-
234-
maps.push(imageData);
200+
`;
201+
202+
const computeShader = device.createShaderModule({
203+
code: computeShaderCode,
204+
});
205+
206+
const bindGroupLayout = device.createBindGroupLayout({
207+
entries: [
208+
{
209+
binding: 0,
210+
// eslint-disable-next-line no-undef
211+
visibility: GPUShaderStage.COMPUTE,
212+
texture: { sampleType: 'float' },
213+
},
214+
{
215+
binding: 1,
216+
// eslint-disable-next-line no-undef
217+
visibility: GPUShaderStage.COMPUTE,
218+
storageTexture: { format: 'rgba8unorm', access: 'write-only' },
219+
},
220+
{
221+
binding: 2,
222+
// eslint-disable-next-line no-undef
223+
visibility: GPUShaderStage.COMPUTE,
224+
sampler: { type: 'filtering' },
225+
},
226+
],
227+
});
228+
229+
const pipelineLayout = device.createPipelineLayout({
230+
bindGroupLayouts: [bindGroupLayout],
231+
});
232+
233+
const pipeline = device.createComputePipeline({
234+
layout: pipelineLayout,
235+
compute: {
236+
module: computeShader,
237+
entryPoint: 'main',
238+
},
239+
});
240+
241+
const sampler = device.createSampler({
242+
magFilter: 'linear',
243+
minFilter: 'linear',
244+
});
245+
246+
// Generate each mip level
247+
for (let mipLevel = 1; mipLevel < mipLevelCount; mipLevel++) {
248+
const srcView = texture.createView({
249+
baseMipLevel: mipLevel - 1,
250+
mipLevelCount: 1,
251+
});
252+
253+
const dstView = texture.createView({
254+
baseMipLevel: mipLevel,
255+
mipLevelCount: 1,
256+
});
257+
258+
const bindGroup = device.createBindGroup({
259+
layout: pipeline.getBindGroupLayout(0),
260+
entries: [
261+
{ binding: 0, resource: srcView },
262+
{ binding: 1, resource: dstView },
263+
{ binding: 2, resource: sampler },
264+
],
265+
});
266+
267+
const commandEncoder = device.createCommandEncoder();
268+
const computePass = commandEncoder.beginComputePass();
269+
270+
computePass.setPipeline(pipeline);
271+
computePass.setBindGroup(0, bindGroup);
272+
273+
const mipWidth = Math.max(1, texture.width >> mipLevel);
274+
const mipHeight = Math.max(1, texture.height >> mipLevel);
275+
const workgroupsX = Math.ceil(mipWidth / 8);
276+
const workgroupsY = Math.ceil(mipHeight / 8);
277+
278+
computePass.dispatchWorkgroups(workgroupsX, workgroupsY);
279+
computePass.end();
280+
281+
device.queue.submit([commandEncoder.finish()]);
235282
}
236-
return maps;
237283
};
238284

239285
// ----------------------------------------------------------------------------

0 commit comments

Comments
 (0)