Skip to content

Commit f17568f

Browse files
committed
Compute progress
- Create ComputePipeline, add command encoder methods to diispatch work and use pipeline - Add storageTexture type for compute shaders in bind groups
1 parent 32494a5 commit f17568f

File tree

15 files changed

+392
-231
lines changed

15 files changed

+392
-231
lines changed

packages/arisu-gfx/bind_group.lua

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,16 @@ BindGroup.__index = BindGroup
88
--- | "FRAGMENT"
99
--- | "COMPUTE"
1010

11+
---@alias gfx.StorageAccess
12+
--- | "READ_ONLY"
13+
--- | "WRITE_ONLY"
14+
--- | "READ_WRITE"
15+
1116
---@alias gfx.BindGroupEntry
1217
--- | { type: "buffer", binding: number, buffer: gfx.Buffer, visibility: gfx.ShaderStage[] }
1318
--- | { type: "sampler", binding: number, sampler: gfx.Sampler, visibility: gfx.ShaderStage[] }
1419
--- | { type: "texture", binding: number, texture: gfx.Texture, visibility: gfx.ShaderStage[] }
20+
--- | { type: "storageTexture", binding: number, texture: gfx.Texture, access: gfx.StorageAccess, visibility: gfx.ShaderStage[] }
1521

1622
---@param entries gfx.BindGroupEntry[]
1723
function BindGroup.new(entries)

packages/arisu-gfx/command_buffer/gl.lua

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ local pipelines = setmetatable({}, {
4444
__mode = "k",
4545
})
4646

47+
---@type table<gfx.gl.ComputePipeline, gfx.gl.RawComputePipeline>
48+
local computePipelines = setmetatable({}, {
49+
__mode = "k",
50+
})
51+
4752
---@type table<gfx.IndexFormat, number>
4853
local indexFormatToGL = {
4954
[gfx.IndexType.u16] = gl.UNSIGNED_SHORT,
@@ -62,7 +67,17 @@ local compareFnsMap = {
6267
[gfx.CompareFunction.ALWAYS] = gl.ALWAYS,
6368
}
6469

70+
---@type table<gfx.StorageAccess, number>
71+
local accessMap = {
72+
["READ_ONLY"] = gl.READ_ONLY,
73+
["WRITE_ONLY"] = gl.WRITE_ONLY,
74+
["READ_WRITE"] = gl.READ_WRITE,
75+
}
76+
6577
function GLCommandBuffer:execute()
78+
---@type gfx.gl.ComputePipeline?
79+
local computePipeline
80+
6681
---@type gfx.gl.Pipeline?
6782
local pipeline
6883

@@ -164,10 +179,26 @@ function GLCommandBuffer:execute()
164179
elseif entry.type == "sampler" then
165180
local sampler = entry.sampler --[[@as gfx.gl.Sampler]]
166181
gl.bindSampler(entry.binding, sampler.id)
182+
elseif entry.type == "storageTexture" then
183+
-- TODO: Look into the format here
184+
local texture = entry.texture --[[@as gfx.gl.Texture]]
185+
gl.bindImageTexture(entry.binding, texture.id, 0, 1, 0, accessMap[entry.access], gl.RGBA8)
167186
end
168187
end
169188
elseif command.type == "drawIndexed" then
170189
gl.drawElements(gl.TRIANGLES, command.indexCount, indexType, nil)
190+
elseif command.type == "beginComputePass" then
191+
elseif command.type == "setComputePipeline" then
192+
computePipeline = command.pipeline
193+
194+
local rawComputePipeline = computePipelines[computePipeline]
195+
if not rawComputePipeline then
196+
rawComputePipeline = computePipeline:genForCurrentContext()
197+
computePipelines[computePipeline] = rawComputePipeline
198+
end
199+
rawComputePipeline:bind()
200+
elseif command.type == "dispatchWorkgroups" then
201+
gl.dispatchCompute(command.x, command.y, command.z)
171202
else
172203
print("Unknown command type: " .. tostring(command.type))
173204
end

packages/arisu-gfx/command_encoder.lua

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
---@field colorAttachments { op: gfx.LoadOp, texture: gfx.Texture }[]
2020
---@field depthStencilAttachment? { op: gfx.DepthOp, texture: gfx.Texture }
2121

22+
---@class gfx.ComputePassDescriptor
23+
2224
---@class gfx.CommandEncoder
2325
---@field finish fun(self: gfx.CommandEncoder): gfx.CommandBuffer
2426
---@field beginRendering fun(self: gfx.CommandEncoder, descriptor: gfx.RenderPassDescriptor)
@@ -32,6 +34,10 @@
3234
---@field drawIndexed fun(self: gfx.CommandEncoder, indexCount: number, instanceCount: number, firstIndex: number?, baseVertex: number?, firstInstance: number?)
3335
---@field writeBuffer fun(self: gfx.CommandEncoder, buffer: gfx.Buffer, size: number, data: ffi.cdata*, offset: number?)
3436
---@field writeTexture fun(self: gfx.CommandEncoder, texture: gfx.Texture, descriptor: gfx.TextureWriteDescriptor, data: ffi.cdata*)
37+
--- Compute
38+
---@field beginComputePass fun(self: gfx.CommandEncoder, descriptor: gfx.ComputePassDescriptor)
39+
---@field dispatchWorkgroups fun(self: gfx.CommandEncoder, x: number, y: number, z: number)
40+
---@field setComputePipeline fun(self: gfx.CommandEncoder, pipeline: gfx.ComputePipeline)
3541
local Encoder = require("arisu-gfx.encoder.gl") --[[@as gfx.CommandEncoder]]
3642

3743
return Encoder

packages/arisu-gfx/command_encoder/gl.lua

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ local GLCommandBuffer = require("arisu-gfx.command_buffer.gl")
1212
---| { type: "drawIndexed", indexCount: number, instanceCount: number, firstIndex: number, baseVertex: number, firstInstance: number }
1313
---| { type: "writeBuffer", buffer: gfx.gl.Buffer, size: number, data: ffi.cdata*, offset: number }
1414
---| { type: "writeTexture", texture: gfx.gl.Texture, descriptor: gfx.TextureWriteDescriptor, data: ffi.cdata* }
15+
--- # Compute
16+
---| { type: "beginComputePass", descriptor: gfx.ComputePassDescriptor }
17+
---| { type: "dispatchWorkgroups", x: number, y: number, z: number }
18+
---| { type: "setComputePipeline", pipeline: gfx.gl.ComputePipeline }
1519

1620
---@class gfx.gl.Encoder
1721
---@field commands gfx.gl.Command[]
@@ -119,6 +123,27 @@ function GLCommandEncoder:writeTexture(texture, descriptor, data)
119123
}
120124
end
121125

126+
--[[
127+
Compute Functions
128+
]]
129+
130+
---@param descriptor gfx.ComputePassDescriptor
131+
function GLCommandEncoder:beginComputePass(descriptor)
132+
self.commands[#self.commands + 1] = { type = "beginComputePass", descriptor = descriptor }
133+
end
134+
135+
---@param x number
136+
---@param y number
137+
---@param z number
138+
function GLCommandEncoder:dispatchWorkgroups(x, y, z)
139+
self.commands[#self.commands + 1] = { type = "dispatchWorkgroups", x = x, y = y, z = z }
140+
end
141+
142+
---@param pipeline gfx.gl.ComputePipeline
143+
function GLCommandEncoder:setComputePipeline(pipeline)
144+
self.commands[#self.commands + 1] = { type = "setComputePipeline", pipeline = pipeline }
145+
end
146+
122147
function GLCommandEncoder:finish()
123148
return GLCommandBuffer.new(self.commands)
124149
end
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---@class gfx.ComputePipelineDescriptor
2+
---@field module gfx.ShaderModule
3+
4+
---@class gfx.ComputePipeline
5+
local ComputePipeline = require("arisu-gfx.compute_pipeline.gl") --[[@as gfx.ComputePipeline]]
6+
7+
return ComputePipeline
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
local gl = require("arisu-opengl")
2+
local ffi = require("ffi")
3+
4+
local GLProgram = require("arisu-gfx.gl_program")
5+
6+
---@class gfx.gl.ComputePipeline
7+
---@field module gfx.ShaderModule
8+
local GLComputePipeline = {}
9+
GLComputePipeline.__index = GLComputePipeline
10+
11+
---@param device gfx.gl.Device
12+
---@param descriptor gfx.ComputePipelineDescriptor
13+
function GLComputePipeline.new(device, descriptor)
14+
if descriptor.module.type ~= "glsl" then
15+
error("Only GLSL shaders are supported in the OpenGL backend.")
16+
end
17+
18+
return setmetatable({ module = descriptor.module }, GLComputePipeline)
19+
end
20+
21+
---@class gfx.gl.RawComputePipeline
22+
---@field id number
23+
local GLRawComputePipeline = {}
24+
GLRawComputePipeline.__index = GLRawComputePipeline
25+
26+
function GLRawComputePipeline.new(id)
27+
return setmetatable({ id = id }, GLRawComputePipeline)
28+
end
29+
30+
--- GL Specific pipeline generation for the current context
31+
--- This isn't done at pipeline creation time because contexts cannot share pipelines.
32+
function GLComputePipeline:genForCurrentContext()
33+
local pipeline = gl.genProgramPipelines(1)[1]
34+
35+
local program = GLProgram.new(gl.ShaderType.COMPUTE, self.module.source)
36+
gl.useProgramStages(pipeline, gl.COMPUTE_SHADER_BIT, program.id)
37+
38+
return GLRawComputePipeline.new(pipeline)
39+
end
40+
41+
function GLRawComputePipeline:bind()
42+
gl.bindProgramPipeline(self.id)
43+
end
44+
45+
function GLRawComputePipeline:destroy()
46+
gl.deleteProgramPipelines(1, ffi.new("GLuint[1]", self.id))
47+
end
48+
49+
function GLRawComputePipeline:__tostring()
50+
return "GLRawComputePipeline(" .. tostring(self.id) .. ")"
51+
end
52+
53+
return GLComputePipeline

packages/arisu-gfx/device.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
---@field createBindGroup fun(self: gfx.Device, entries: gfx.BindGroupEntry[]): gfx.BindGroup
77
---@field createSampler fun(self: gfx.Device, descriptor: gfx.SamplerDescriptor): gfx.Sampler
88
---@field createTexture fun(self: gfx.Device, descriptor: gfx.TextureDescriptor): gfx.Texture
9+
---@field createComputePipeline fun(self: gfx.Device, descriptor: gfx.ComputePipelineDescriptor): gfx.ComputePipeline
910
local Device = require("arisu-gfx.device.gl") --[[@as gfx.Device]]
1011

1112
return Device

packages/arisu-gfx/device/gl.lua

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ local GLPipeline = require("arisu-gfx.pipeline.gl")
66
local GLBindGroup = require("arisu-gfx.bind_group")
77
local GLSampler = require("arisu-gfx.sampler.gl")
88
local GLTexture = require("arisu-gfx.texture.gl")
9+
local GLComputePipeline = require("arisu-gfx.compute_pipeline.gl")
910

1011
---@class gfx.gl.Device
1112
---@field public queue gfx.gl.Queue
@@ -54,4 +55,10 @@ function GLDevice:createTexture(descriptor)
5455
return GLTexture.new(self, descriptor)
5556
end
5657

58+
---@param descriptor gfx.ComputePipelineDescriptor
59+
function GLDevice:createComputePipeline(descriptor)
60+
self.ctx:makeCurrent()
61+
return GLComputePipeline.new(self, descriptor)
62+
end
63+
5764
return GLDevice

packages/arisu-gfx/gl_context/x11.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function X11Context.new(display, sharedCtx, window)
3030

3131
local ctx = glx.createContextAttribsARB(display, fbConfig, sharedCtx and sharedCtx.ctx, 1, {
3232
glx.CONTEXT_MAJOR_VERSION_ARB,
33-
3,
33+
4,
3434
glx.CONTEXT_MINOR_VERSION_ARB,
3535
3,
3636
})

packages/arisu-gfx/init.lua

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,8 @@ gfx.IndexType = {
6060
u32 = 2,
6161
}
6262

63+
---@alias gfx.ShaderModule
64+
---| { type: "glsl", source: string }
65+
---| { type: "spirv", source: string }
66+
6367
return gfx

0 commit comments

Comments
 (0)