diff --git a/Makefile b/Makefile
index 4bdf0e260e..7ec9b02db7 100644
--- a/Makefile
+++ b/Makefile
@@ -50,6 +50,9 @@ unittest:
 test-oclgrind:
 	cabal run -- futhark test tests -c --backend=opencl --exclude=compiled --exclude=no_oclgrind --cache-extension=cache --pass-option=--build-option=-O0 --runner=tools/oclgrindrunner.sh
 
+test-webgpu:
+	cabal run -- futhark test tests -c --backend=webgpu --runner tools/browser_test.py --no-tuning
+
 test-t:
 	cabal run -- futhark test tests -t
diff --git a/cabal.project b/cabal.project
index f55a150ff4..c66545d56d 100644
--- a/cabal.project
+++ b/cabal.project
@@ -1,6 +1,6 @@
-packages: futhark.cabal
+packages: futhark.cabal, language-wgsl
 index-state: 2025-07-03T02:06:05Z
 multi-repl: True
 
 package futhark
   ghc-options: -j -fwrite-ide-info -hiedir=.hie
diff --git a/default.nix b/default.nix
index 801f289edc..721b85a6cd 100644
--- a/default.nix
+++ b/default.nix
@@ -34,6 +34,9 @@ let
   futhark-manifest =
     haskellPackagesNew.callPackage ./nix/futhark-manifest.nix { };
 
+  language-wgsl =
+    haskellPackagesOld.callCabal2nix "language-wgsl" ./language-wgsl {};
+
   futhark =
     # callCabal2Nix does not do a great job at determining
     # which files must be included as source, which causes
diff --git a/futhark.cabal b/futhark.cabal
index c54d094da6..4acbeaeb5e 100644
--- a/futhark.cabal
+++ b/futhark.cabal
@@ -47,6 +47,7 @@ extra-source-files:
   rts/c/backends/hip.h
   rts/c/backends/multicore.h
   rts/c/backends/opencl.h
+  rts/c/backends/webgpu.h
   rts/c/lock.h
   rts/c/copy.h
   rts/c/timing.h
@@ -81,6 +82,23 @@ extra-source-files:
   rts/python/values.py
   rts/python/opencl.py
   rts/python/scalar.py
+  rts/webgpu/server_ws.js
+  rts/webgpu/util.js
+  rts/webgpu/values.js
+  rts/webgpu/wrappers.js
+  rts/wgsl/scalar.wgsl
+  rts/wgsl/scalar8.wgsl
+  rts/wgsl/scalar16.wgsl
+  rts/wgsl/scalar32.wgsl
+  rts/wgsl/scalar64.wgsl
+  rts/wgsl/atomics.wgsl
+  rts/wgsl/builtin_kernels.wgsl
+  rts/wgsl/lmad_copy.wgsl
+  rts/wgsl/map_transpose.wgsl
+  rts/wgsl/map_transpose_low_height.wgsl
+  rts/wgsl/map_transpose_low_width.wgsl
+  rts/wgsl/map_transpose_small.wgsl
+  rts/wgsl/map_transpose_large.wgsl
   prelude/functional.fut
   prelude/math.fut
   prelude/soacs.fut
@@ -178,8 +196,10 @@ library
     Futhark.CLI.Script
     Futhark.CLI.Test
     Futhark.CLI.WASM
+    Futhark.CLI.WebGPU
     Futhark.CodeGen.Backends.CCUDA
     Futhark.CodeGen.Backends.COpenCL
+    Futhark.CodeGen.Backends.CWebGPU
     Futhark.CodeGen.Backends.HIP
     Futhark.CodeGen.Backends.GenericC
     Futhark.CodeGen.Backends.GenericC.CLI
@@ -212,11 +232,15 @@ library
     Futhark.CodeGen.RTS.OpenCL
     Futhark.CodeGen.RTS.Python
     Futhark.CodeGen.RTS.JavaScript
+    Futhark.CodeGen.RTS.WebGPU
+    Futhark.CodeGen.RTS.WGSL
     Futhark.CodeGen.ImpCode
     Futhark.CodeGen.ImpCode.GPU
+    Futhark.CodeGen.ImpCode.Kernels
     Futhark.CodeGen.ImpCode.Multicore
     Futhark.CodeGen.ImpCode.OpenCL
     Futhark.CodeGen.ImpCode.Sequential
+    Futhark.CodeGen.ImpCode.WebGPU
     Futhark.CodeGen.ImpGen
     Futhark.CodeGen.ImpGen.CUDA
     Futhark.CodeGen.ImpGen.GPU
@@ -238,6 +262,7 @@ library
     Futhark.CodeGen.ImpGen.Multicore.SegScan
     Futhark.CodeGen.ImpGen.OpenCL
     Futhark.CodeGen.ImpGen.Sequential
+    Futhark.CodeGen.ImpGen.WebGPU
     Futhark.CodeGen.OpenCL.Heuristics
     Futhark.Compiler
     Futhark.Compiler.CLI
@@ -383,6 +408,7 @@ library
     Futhark.Script
     Futhark.Test
     Futhark.Test.Spec
+    Futhark.Test.WebGPUTest
     Futhark.Test.Values
     Futhark.Tools
     Futhark.Transform.CopyPropagate
@@ -498,6 +524,7 @@ library
     , mwc-random
     , prettyprinter >= 1.7
     , prettyprinter-ansi-terminal >= 1.1
+    , language-wgsl
 
 executable futhark
   import: common
diff --git a/language-wgsl/CHANGELOG.md
b/language-wgsl/CHANGELOG.md new file mode 100644 index 0000000000..15b63dceba --- /dev/null +++ b/language-wgsl/CHANGELOG.md @@ -0,0 +1,5 @@ +# Revision history for language-wgsl + +## 0.1.0.0 -- YYYY-mm-dd + +* First version. Released on an unsuspecting world. diff --git a/language-wgsl/LICENSE b/language-wgsl/LICENSE new file mode 100644 index 0000000000..7b9e44dd3a --- /dev/null +++ b/language-wgsl/LICENSE @@ -0,0 +1,13 @@ +Copyright (c) 2024 Sebastian Paarmann and the University of Copenhagen + +Permission to use, copy, modify, and/or distribute this software for any purpose +with or without fee is hereby granted, provided that the above copyright notice +and this permission notice appear in all copies. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH +REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND +FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, +INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS +OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER +TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF +THIS SOFTWARE. diff --git a/language-wgsl/language-wgsl.cabal b/language-wgsl/language-wgsl.cabal new file mode 100644 index 0000000000..fcb12946ad --- /dev/null +++ b/language-wgsl/language-wgsl.cabal @@ -0,0 +1,43 @@ +cabal-version: 3.4 +name: language-wgsl +version: 0.1.0.0 +synopsis: The WebGPU Shader Language. + +description: A package implementing a (partial) AST definition for the + WebGPU Shader Language, including a prettyprinter. + +license: ISC +license-file: LICENSE +author: Sebastian Paarmann +maintainer: sepa@di.ku.dk +-- copyright: +category: Language +build-type: Simple +extra-doc-files: CHANGELOG.md +-- extra-source-files: + +common warnings + ghc-options: -Wall + +library + import: warnings + exposed-modules: Language.WGSL + build-depends: base >=4.15 && <5, + text, + prettyprinter >= 1.7 + hs-source-dirs: src + default-language: GHC2021 + default-extensions: + OverloadedStrings + +test-suite language-wgsl-test + import: warnings + default-language: GHC2021 + default-extensions: + OverloadedStrings + type: exitcode-stdio-1.0 + hs-source-dirs: test + main-is: Main.hs + build-depends: + base >=4.15 && <5, + language-wgsl diff --git a/language-wgsl/src/Language/WGSL.hs b/language-wgsl/src/Language/WGSL.hs new file mode 100644 index 0000000000..d11f377931 --- /dev/null +++ b/language-wgsl/src/Language/WGSL.hs @@ -0,0 +1,363 @@ +module Language.WGSL + ( Ident, + PrimType (..), + hsLayout, + structLayout, + Typ (..), + BinOp, + UnOp, + Exp (..), + Stmt (..), + Attrib (..), + Param (..), + Function (..), + Field (..), + Struct (..), + Declaration (..), + AccessMode (..), + AddressSpace (..), + stmts, + bindingAttribs, + to_i32, + prettyDecls, + ) +where + +import Data.Maybe (fromMaybe) +import Data.Text qualified as T +import Prettyprinter + +type Ident = T.Text + +data PrimType + = Bool + | Int32 + | UInt32 + | Float16 + | Float32 + | Vec2 PrimType + | Vec3 PrimType + | Vec4 PrimType + | Atomic PrimType + deriving (Show) + +-- | AlignOf and SizeOf of host-shareable primitive types. +-- +-- See https://www.w3.org/TR/WGSL/#alignment-and-size. 
+hsLayout :: PrimType -> Maybe (Int, Int) +hsLayout Bool = Nothing +hsLayout Int32 = Just (4, 4) +hsLayout UInt32 = Just (4, 4) +hsLayout Float16 = Just (2, 2) +hsLayout Float32 = Just (4, 4) +hsLayout (Vec2 t) = + (\(a, s) -> (a * 2, s * 2)) <$> hsLayout t +hsLayout (Vec3 t) = + (\(a, s) -> (a * 4, s * 3)) <$> hsLayout t +hsLayout (Vec4 t) = + (\(a, s) -> (a * 4, s * 4)) <$> hsLayout t +hsLayout (Atomic t) = hsLayout t + +-- | Field offsets, AlignOf, and SizeOf of a host-shareable struct type with the +-- given fields. +structLayout :: [PrimType] -> Maybe ([Int], Int, Int) +structLayout [] = Nothing +structLayout fields = do + fieldLayouts <- mapM hsLayout fields + let (fieldAligns, fieldSizes) = unzip fieldLayouts + let structAlign = maximum fieldAligns + let fieldOffsets = + scanl + (\prev_off (al, prev_sz) -> roundUp al (prev_off + prev_sz)) + 0 + (zip (drop 1 fieldAligns) fieldSizes) + let structSize = roundUp structAlign (last fieldOffsets + last fieldSizes) + pure (fieldOffsets, structAlign, structSize) + where + roundUp k n = ceiling ((fromIntegral n :: Double) / fromIntegral k) * k + +data Typ = Prim PrimType | Array PrimType (Maybe Exp) | Named Ident | Pointer PrimType AddressSpace (Maybe AccessMode) + +type BinOp = T.Text + +type UnOp = T.Text + +data Exp + = BoolExp Bool + | IntExp Int + | FloatExp Double + | StringExp T.Text + | VarExp Ident + | BinOpExp BinOp Exp Exp + | UnOpExp UnOp Exp + | CallExp Ident [Exp] + | IndexExp Ident Exp + | FieldExp Exp Ident + +data Stmt + = Skip + | Comment T.Text + | Seq Stmt Stmt + | Let Ident Exp + | DeclareVar Ident Typ + | Assign Ident Exp + | AssignIndex Ident Exp Exp + | If Exp Stmt Stmt + | For Ident Exp Exp Stmt Stmt + | While Exp Stmt + | Call Ident [Exp] + +data Attrib = Attrib Ident [Exp] + +data Param = Param Ident Typ [Attrib] + +data Function = Function + { funName :: Ident, + funAttribs :: [Attrib], + funParams :: [Param], + funOutput :: [Param], + funBody :: Stmt + } + +data Field = Field Ident Typ + +data Struct = Struct Ident [Field] + +data AccessMode = ReadOnly | ReadWrite + +-- Uniform buffers are always read-only. +data AddressSpace = Storage AccessMode | Uniform | Workgroup | FunctionSpace + +data Declaration + = FunDecl Function + | StructDecl Struct + | VarDecl [Attrib] AddressSpace Ident Typ + | OverrideDecl Ident Typ (Maybe Exp) + +stmts :: [Stmt] -> Stmt +stmts [] = Skip +stmts [s] = s +stmts (s : ss) = Seq s (stmts ss) + +bindingAttribs :: Int -> Int -> [Attrib] +bindingAttribs grp binding = + [Attrib "group" [IntExp grp], Attrib "binding" [IntExp binding]] + +to_i32 :: Exp -> Exp +to_i32 e = CallExp "bitcast" [e] + +--- Prettyprinting definitions + +-- | Separate with commas. +commasep :: [Doc a] -> Doc a +commasep = hsep . punctuate comma + +-- | Like commasep, but a newline after every comma. +commastack :: [Doc a] -> Doc a +commastack = align . vsep . punctuate comma + +-- | Separate with semicolons and newlines. +semistack :: [Doc a] -> Doc a +semistack = align . vsep . punctuate semi + +-- | Separate with linebreaks. +stack :: [Doc a] -> Doc a +stack = align . mconcat . 
punctuate line
+
+(</>) :: Doc a -> Doc a -> Doc a
+a </> b = a <> line <> b
+
+instance Pretty PrimType where
+  pretty Bool = "bool"
+  pretty Int32 = "i32"
+  pretty UInt32 = "u32"
+  pretty Float16 = "f16"
+  pretty Float32 = "f32"
+  pretty (Vec2 t) = "vec2<" <> pretty t <> ">"
+  pretty (Vec3 t) = "vec3<" <> pretty t <> ">"
+  pretty (Vec4 t) = "vec4<" <> pretty t <> ">"
+  pretty (Atomic t) = "atomic<" <> pretty t <> ">"
+
+instance Pretty Typ where
+  pretty (Prim t) = pretty t
+  pretty (Array t Nothing) = "array<" <> pretty t <> ">"
+  pretty (Array t sz) = "array<" <> pretty t <> ", " <> pretty sz <> ">"
+  pretty (Named t) = pretty t
+  pretty (Pointer t as am) = "ptr<" <> pretty as <> ", " <> pretty t <> maybe "" pretty am <> ">"
+
+instance Pretty Exp where
+  pretty (BoolExp True) = "true"
+  pretty (BoolExp False) = "false"
+  pretty (IntExp x) = pretty x
+  pretty (FloatExp x) = pretty x
+  pretty (StringExp x) = pretty $ show x
+  pretty (VarExp x) = pretty x
+  pretty (UnOpExp op e) = parens (pretty op <> pretty e)
+  pretty (BinOpExp op e1 e2) = parens (pretty e1 <+> pretty op <+> pretty e2)
+  pretty (CallExp f args) = pretty f <> parens (commasep $ map pretty args)
+  pretty (IndexExp x i) = pretty x <> brackets (pretty i)
+  pretty (FieldExp x y) = pretty x <> "." <> pretty y
+
+instance Pretty Stmt where
+  pretty Skip = ";"
+  pretty (Comment c) = vsep (map ("//" <+>) (pretty <$> T.lines c))
+  pretty (Seq s1 s2) = semistack [pretty s1, pretty s2]
+  pretty (Let x e) = "let" <+> pretty x <+> "=" <+> pretty e
+  pretty (DeclareVar x t) = "var" <+> pretty x <+> ":" <+> pretty t
+  pretty (Assign x e) = pretty x <+> "=" <+> pretty e
+  pretty (AssignIndex x i e) =
+    pretty x <> brackets (pretty i) <+> "=" <+> pretty e
+  pretty (If cond Skip Skip) = "if" <+> pretty cond <+> "{ }"
+  pretty (If cond th Skip) =
+    "if"
+      <+> pretty cond
+      <+> "{"
+      </> indent 2 (pretty th)
+      <> ";"
+      </> "}"
+  pretty (If cond Skip el) =
+    "if"
+      <+> pretty cond
+      <+> "{ }"
+      </> "else {"
+      </> indent 2 (pretty el)
+      <> ";"
+      </> "}"
+  pretty (If cond th el) =
+    "if"
+      <+> pretty cond
+      <+> "{"
+      </> indent 2 (pretty th)
+      <> ";"
+      </> "} else {"
+      </> indent 2 (pretty el)
+      <> ";"
+      </> "}"
+  pretty (For x initializer cond upd body) =
+    "for"
+      <+> parens
+        ( "var"
+            <+> pretty x
+            <+> "="
+            <+> pretty initializer
+            <> ";"
+            <+> pretty cond
+            <> ";"
+            <+> pretty upd
+        )
+      <+> "{"
+      </> indent 2 (pretty body)
+      <> ";"
+      </> "}"
+  pretty (While cond body) =
+    "while"
+      <+> pretty cond
+      <+> "{"
+      </> indent 2 (pretty body)
+      <> ";"
+      </> "}"
+  pretty (Call f args) = pretty f <> parens (commasep $ map pretty args)
+
+instance Pretty Attrib where
+  pretty (Attrib name []) = "@" <> pretty name
+  pretty (Attrib name args) =
+    "@" <> pretty name <> parens (commasep $ map pretty args)
+
+instance Pretty Param where
+  pretty (Param name typ attribs)
+    | null attribs = pretty name <+> ":" <+> pretty typ
+    | otherwise =
+        stack
+          [ hsep (map pretty attribs),
+            pretty name <+> ":" <+> pretty typ
+          ]
+
+prettyParams :: [Param] -> Doc a
+prettyParams [] = "()"
+prettyParams params = "(" </> indent 2 (commastack (map pretty params)) </> ")"
+
+prettyAssignOutParams :: [Param] -> Doc a
+prettyAssignOutParams [] = ""
+prettyAssignOutParams params = stack (map prettyAssign params)
+  where
+    prettyAssign (Param name _ _) =
+      indent 2 "*" <> pretty name <> " = " <> pretty (T.stripSuffix "_out" name) <> ";"
+
+instance Pretty Function where
+  pretty (Function name attribs in_params out_params body) = do
+    stack $ hsep (map pretty attribs) : function
+    where
+      funBody = indent 2 (pretty body) <> ";"
+      funDecls =
+        let local_decls =
+              map
+                ( \(Param v typ _) -> case typ of
+                    Pointer t _ _ -> DeclareVar (fromMaybe v (T.stripSuffix "_out" v)) (Prim t)
+                    _ -> error "Can only return primitive types!"
+                )
+                out_params
+         in stack (map (\decl -> indent 2 (pretty decl) <> ";") local_decls)
+      function = case out_params of
+        [] ->
+          ["fn" <+> pretty name <> prettyParams in_params <+> "{", funBody, "}"]
+        [Param ret_id (Pointer t _ _) _] ->
+          [ "fn" <+> pretty name <> prettyParams in_params <+> "->" <+> pretty t <+> "{",
+            funDecls,
+            funBody,
+            indent 2 "return " <> pretty (T.stripSuffix "_out" ret_id) <> ";",
+            "}"
+          ]
+        _ ->
+          [ "fn" <+> pretty name <> prettyParams (in_params ++ out_params) <+> "{",
+            funDecls,
+            funBody,
+            prettyAssignOutParams out_params,
+            "}"
+          ]
+
+instance Pretty Field where
+  pretty (Field name typ) = pretty name <+> ":" <+> pretty typ
+
+instance Pretty Struct where
+  pretty (Struct name fields) =
+    "struct"
+      <+> pretty name
+      <+> "{"
+      </> indent 2 (commastack (map pretty fields))
+      </> "}"
+
+instance Pretty AccessMode where
+  pretty ReadOnly = "read"
+  pretty ReadWrite = "read_write"
+
+instance Pretty AddressSpace where
+  pretty (Storage am) = "storage" <> "," <> pretty am
+  pretty Uniform = "uniform"
+  pretty Workgroup = "workgroup"
+  pretty FunctionSpace = "function"
+
+instance Pretty Declaration where
+  pretty (FunDecl fun) = pretty fun
+  pretty (StructDecl struct) = pretty struct
+  pretty (VarDecl attribs as name typ) =
+    hsep (map pretty attribs)
+      </> "var<"
+      <> pretty as
+      <> ">"
+      <+> pretty name
+      <+> ":"
+      <+> pretty typ
+      <> ";"
+  pretty (OverrideDecl name typ Nothing) =
+    "override" <+> pretty name <+> ":" <+> pretty typ <> ";"
+  pretty (OverrideDecl name typ (Just initial)) =
+    "override"
+      <+> pretty name
+      <+> ":"
+      <+> pretty typ
+      <+> "="
+      <+> pretty initial
+      <> ";"
+
+prettyDecls :: [Declaration] -> Doc a
+prettyDecls decls = stack (map pretty decls)
diff --git a/language-wgsl/test/Main.hs b/language-wgsl/test/Main.hs
new file mode 100644
index 0000000000..3e2059e31f
--- /dev/null
+++ b/language-wgsl/test/Main.hs
@@ -0,0 +1,4 @@
+module Main (main) where
+
+main :: IO ()
+main = putStrLn "Test suite not yet implemented."
diff --git a/rts/c/backends/opencl.h b/rts/c/backends/opencl.h
index 73a02a58c9..70d6a2ebb5 100644
--- a/rts/c/backends/opencl.h
+++ b/rts/c/backends/opencl.h
@@ -1370,7 +1370,7 @@ static int gpu_launch_kernel(struct futhark_context* ctx,
 // will go through the free list).
 static int gpu_alloc_actual(struct futhark_context *ctx, size_t size, gpu_mem *mem_out) {
   int error;
-  *mem_out = clCreateBuffer(ctx->ctx, CL_MEM_READ_WRITE, size, NULL, &error);
+  *mem_out = clCreateBuffer(ctx->ctx, CL_MEM_READ_WRITE, size+1024, NULL, &error);
   OPENCL_SUCCEED_OR_RETURN(error);
diff --git a/rts/c/backends/webgpu.h b/rts/c/backends/webgpu.h
new file mode 100644
index 0000000000..f69f7a582b
--- /dev/null
+++ b/rts/c/backends/webgpu.h
@@ -0,0 +1,1025 @@
+// Start of backends/webgpu.h.
+
+// Synchronous wrapper around asynchronous WebGPU APIs, based on looping with
+// emscripten_sleep until the respective callback gets called.
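// Illustrative usage sketch only: the blocking wrappers defined below turn the
// callback-based WebGPU API into synchronous calls. For example, reading four
// bytes back from a readback buffer that has already been copied into looks
// roughly like this; `instance`, `buffer` and `out` are assumed to exist, and
// the same sequence appears in gpu_scalar_from_device further down.
//
//   WGPUBufferMapAsyncStatus st =
//     wgpu_map_buffer_sync(instance, buffer, WGPUMapMode_Read, 0, 4);
//   if (st == WGPUBufferMapAsyncStatus_Success) {
//     const void *p = wgpuBufferGetConstMappedRange(buffer, 0, 4);
//     memcpy(out, p, 4);     // copy the mapped bytes to host memory
//     wgpuBufferUnmap(buffer);
//   }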
+ +typedef struct wgpu_wait_info { + bool released; + void *result; +} wgpu_wait_info; + +void wgpu_map_sync_callback(WGPUBufferMapAsyncStatus status, void *info_v) { + wgpu_wait_info *info = (wgpu_wait_info *)info_v; + *((WGPUBufferMapAsyncStatus *) info->result) = status; + info->released = true; +} + +WGPUBufferMapAsyncStatus wgpu_map_buffer_sync(WGPUInstance instance, + WGPUBuffer buffer, + WGPUMapModeFlags mode, + size_t offset, size_t size) { + WGPUBufferMapAsyncStatus status; + wgpu_wait_info info = { + .released = false, + .result = (void *)&status, + }; + +#ifdef USE_DAWN + WGPUBufferMapCallbackInfo cb_info = { + .mode = WGPUCallbackMode_WaitAnyOnly, + .callback = wgpu_map_sync_callback, + .userdata = (void *) &info, + }; + WGPUFuture f = wgpuBufferMapAsyncF(buffer, mode, offset, size, cb_info); + WGPUFutureWaitInfo f_info = { .future = f }; + while (!info.released) { + wgpuInstanceWaitAny(instance, 1, &f_info, 0); + } +#else + wgpuBufferMapAsync(buffer, mode, offset, size, + wgpu_map_sync_callback, (void *) &info); + + // TODO: Should this do some kind of volatile load? + // (Same for other _sync wrappers below.) + while (!info.released) { + emscripten_sleep(0); + } +#endif + + return status; +} + +typedef struct wgpu_request_adapter_result { + WGPURequestAdapterStatus status; + WGPUAdapter adapter; + const char *message; +} wgpu_request_adapter_result; + +void wgpu_request_adapter_callback(WGPURequestAdapterStatus status, + WGPUAdapter adapter, + const char *message, void *userdata) { + wgpu_wait_info *info = (wgpu_wait_info *)userdata; + wgpu_request_adapter_result *result + = (wgpu_request_adapter_result *)info->result; + result->status = status; + result->adapter = adapter; + result->message = message; + info->released = true; +} + +wgpu_request_adapter_result wgpu_request_adapter_sync( + WGPUInstance instance, WGPURequestAdapterOptions const * options) { + wgpu_request_adapter_result result = {}; + wgpu_wait_info info = { + .released = false, + .result = (void *)&result, + }; + +#ifdef USE_DAWN + WGPURequestAdapterCallbackInfo cb_info = { + .mode = WGPUCallbackMode_WaitAnyOnly, + .callback = wgpu_request_adapter_callback, + .userdata = (void *) &info, + }; + WGPUFuture f = wgpuInstanceRequestAdapterF(instance, options, cb_info); + WGPUFutureWaitInfo f_info = { .future = f }; + while (!info.released) { + wgpuInstanceWaitAny(instance, 1, &f_info, 0); + } +#else + wgpuInstanceRequestAdapter(instance, options, wgpu_request_adapter_callback, + (void *)&info); + + while (!info.released) { + emscripten_sleep(0); + } +#endif + + return result; +} + +typedef struct wgpu_request_device_result { + WGPURequestDeviceStatus status; + WGPUDevice device; + const char *message; +} wgpu_request_device_result; + +void wgpu_request_device_callback(WGPURequestDeviceStatus status, + WGPUDevice device, + const char *message, void *userdata) { + wgpu_wait_info *info = (wgpu_wait_info *)userdata; + wgpu_request_device_result *result + = (wgpu_request_device_result *)info->result; + result->status = status; + result->device = device; + result->message = message; + info->released = true; +} + +wgpu_request_device_result wgpu_request_device_sync( + WGPUInstance instance, + WGPUAdapter adapter, + WGPUDeviceDescriptor const * descriptor +) { + wgpu_request_device_result result = {}; + wgpu_wait_info info = { + .released = false, + .result = (void *)&result, + }; + +#ifdef USE_DAWN + WGPURequestDeviceCallbackInfo cb_info = { + .mode = WGPUCallbackMode_WaitAnyOnly, + .callback = 
wgpu_request_device_callback, + .userdata = (void *) &info, + }; + WGPUFuture f = wgpuAdapterRequestDeviceF(adapter, descriptor, cb_info); + WGPUFutureWaitInfo f_info = { .future = f }; + while (!info.released) { + wgpuInstanceWaitAny(instance, 1, &f_info, 0); + } +#else + wgpuAdapterRequestDevice(adapter, descriptor, wgpu_request_device_callback, + (void *)&info); + + while (!info.released) { + emscripten_sleep(0); + } +#endif + + return result; +} + +void wgpu_on_work_done_callback(WGPUQueueWorkDoneStatus status, + void *userdata) { + wgpu_wait_info *info = (wgpu_wait_info *)userdata; + *((WGPUQueueWorkDoneStatus *)info->result) = status; + info->released = true; +} + +WGPUQueueWorkDoneStatus wgpu_block_until_work_done(WGPUInstance instance, + WGPUQueue queue) { + WGPUQueueWorkDoneStatus status; + wgpu_wait_info info = { + .released = false, + .result = (void *)&status, + }; + + +#ifdef USE_DAWN + WGPUQueueWorkDoneCallbackInfo cb_info = { + .mode = WGPUCallbackMode_WaitAnyOnly, + .callback = wgpu_on_work_done_callback, + .userdata = (void *) &info, + }; + WGPUFuture f = wgpuQueueOnSubmittedWorkDoneF(queue, cb_info); + WGPUFutureWaitInfo f_info = { .future = f }; + while (!info.released) { + wgpuInstanceWaitAny(instance, 1, &f_info, 0); + } +#else + wgpuQueueOnSubmittedWorkDone(queue, wgpu_on_work_done_callback, (void *)&info); + + while (!info.released) { + emscripten_sleep(0); + } +#endif + + return status; +} + +void wgpu_on_uncaptured_error(WGPUErrorType error_type, const char *msg, + void *userdata) { + futhark_panic(-1, "Uncaptured WebGPU error, type: %d\n%s\n", error_type, msg); +} + +void wgpu_on_shader_compiled(WGPUCompilationInfoRequestStatus status, + struct WGPUCompilationInfo const * compilationInfo, + void * userdata) { + // TODO: Check status, better printing + for (int i = 0; i < compilationInfo->messageCount; i++) { + WGPUCompilationMessage msg = compilationInfo->messages[i]; + printf("Shader compilation message: %s\n", msg.message); + } +} + +struct futhark_context_config { + int in_use; + int debugging; + int profiling; + int logging; + char *cache_fname; + int num_tuning_params; + int64_t *tuning_params; + const char** tuning_param_names; + const char** tuning_param_vars; + const char** tuning_param_classes; + // Uniform fields above. + + char *program; + + struct gpu_config gpu; +}; + +static void backend_context_config_setup(struct futhark_context_config *cfg) { + cfg->program = strconcat(gpu_program); + + cfg->gpu.default_block_size = 256; + cfg->gpu.default_grid_size = 0; // Set properly later. 
+ cfg->gpu.default_tile_size = 32; + cfg->gpu.default_reg_tile_size = 2; + cfg->gpu.default_threshold = 32*1024; + + cfg->gpu.default_block_size_changed = 0; + cfg->gpu.default_grid_size_changed = 0; + cfg->gpu.default_tile_size_changed = 0; +} + +static void backend_context_config_teardown(struct futhark_context_config *cfg) { + free(cfg->program); +} + +const char* futhark_context_config_get_program(struct futhark_context_config *cfg) { + return cfg->program; +} + +void futhark_context_config_set_program(struct futhark_context_config *cfg, const char *s) { + free(cfg->program); + cfg->program = strdup(s); +} + +struct futhark_context { + struct futhark_context_config* cfg; + int detail_memory; + int debugging; + int profiling; + int profiling_paused; + int logging; + lock_t lock; + char *error; + lock_t error_lock; + FILE *log; + struct constants *constants; + struct free_list free_list; + struct event_list event_list; + int64_t peak_mem_usage_default; + int64_t cur_mem_usage_default; + struct program* program; + bool program_initialised; + // Uniform fields above. + + struct tuning_params tuning_params; + // True if a potentially failing kernel has been enqueued. + int32_t failure_is_an_option; + int total_runs; + long int total_runtime; + int64_t peak_mem_usage_device; + int64_t cur_mem_usage_device; + + int num_overrides; + char **override_names; + double *override_values; + + WGPUInstance instance; + WGPUAdapter adapter; + WGPUDevice device; + WGPUQueue queue; + // One module contains all the kernels as separate entry points. + WGPUShaderModule module; + + WGPUBuffer scalar_readback_buffer; + struct free_list gpu_free_list; + + size_t lockstep_width; + size_t max_thread_block_size; + size_t max_grid_size; + size_t max_tile_size; + size_t max_threshold; + size_t max_shared_memory; + size_t max_registers; + size_t max_cache; + + struct builtin_kernels* kernels; +}; + +int futhark_context_sync(struct futhark_context *ctx) { + // TODO: All the error handling stuff. + WGPUQueueWorkDoneStatus status = wgpu_block_until_work_done(ctx->instance, + ctx->queue); + if (status != WGPUQueueWorkDoneStatus_Success) { + futhark_panic(-1, "Failed to wait for work to be done, status: %d\n", + status); + } + return FUTHARK_SUCCESS; +} + +static void wgpu_size_setup(struct futhark_context *ctx) { + struct futhark_context_config *cfg = ctx->cfg; + // TODO: Deal with the device limits here, see cuda.h. + + // TODO: See if we can also do some proper heuristic for default_grid_size + // here. + if (!cfg->gpu.default_grid_size_changed) { + cfg->gpu.default_grid_size = 16; + } + + for (int i = 0; i < cfg->num_tuning_params; i++) { + const char *size_class = cfg->tuning_param_classes[i]; + int64_t *size_value = &cfg->tuning_params[i]; + const char* size_name = cfg->tuning_param_names[i]; + //int64_t max_value = 0; + int64_t default_value = 0; + + if (strstr(size_class, "thread_block_size") == size_class) { + //max_value = ctx->max_thread_block_size; + default_value = cfg->gpu.default_block_size; + } else if (strstr(size_class, "grid_size") == size_class) { + //max_value = ctx->max_grid_size; + default_value = cfg->gpu.default_grid_size; + // XXX: as a quick and dirty hack, use twice as many threads for + // histograms by default. We really should just be smarter + // about sizes somehow. 
+ if (strstr(size_name, ".seghist_") != NULL) { + default_value *= 2; + } + } else if (strstr(size_class, "tile_size") == size_class) { + //max_value = ctx->max_tile_size; + default_value = cfg->gpu.default_tile_size; + } else if (strstr(size_class, "reg_tile_size") == size_class) { + //max_value = 0; // No limit. + default_value = cfg->gpu.default_reg_tile_size; + } else if (strstr(size_class, "shared_memory") == size_class) { + default_value = ctx->max_shared_memory; + } else if (strstr(size_class, "cache") == size_class) { + default_value = ctx->max_cache; + } else if (strstr(size_class, "threshold") == size_class) { + // Threshold can be as large as it takes. + default_value = cfg->gpu.default_threshold; + } else { + // Bespoke sizes have no limit or default. + } + + if (*size_value == 0) { + *size_value = default_value; + //} else if (max_value > 0 && *size_value > max_value) { + // fprintf(stderr, "Note: Device limits %s to %zu (down from %zu)\n", + // size_name, max_value, *size_value); + // *size_value = max_value; + } + } +} + +void wgpu_module_setup(struct futhark_context *ctx, const char *program, WGPUShaderModule *module, const char* label) { + WGPUShaderModuleWGSLDescriptor wgsl_desc = { + .chain = { + .sType = WGPUSType_ShaderModuleWGSLDescriptor + }, + .code = program + }; + WGPUShaderModuleDescriptor desc = { + .label = label, + .nextInChain = &wgsl_desc.chain + }; + *module = wgpuDeviceCreateShaderModule(ctx->device, &desc); + + wgpuShaderModuleGetCompilationInfo(*module, wgpu_on_shader_compiled, NULL); +} + +struct builtin_kernels* init_builtin_kernels(struct futhark_context* ctx); +void free_builtin_kernels(struct futhark_context* ctx, struct builtin_kernels* kernels); + +int backend_context_setup(struct futhark_context *ctx) { + ctx->failure_is_an_option = 0; + ctx->total_runs = 0; + ctx->total_runtime = 0; + ctx->peak_mem_usage_device = 0; + ctx->cur_mem_usage_device = 0; + ctx->kernels = NULL; + + // These are the default limits from the spec, which will always be the actual + // limit unless we explicitly request a larger one (which we do not currently + // do). + ctx->max_thread_block_size = 256; + ctx->max_grid_size = 65536; // TODO: idk what these should be, just put a large enough value. + ctx->max_tile_size = 65536; // TODO: idk what these should be, just put a large enough value. + ctx->max_threshold = 65536; // TODO: idk what these should be, just put a large enough value. + ctx->max_shared_memory = 16384; + + ctx->max_registers = 65536; // TODO: idk what these should be, just put a large enough value. + + // This is a number we picked semi-arbitrarily (2 MiB). There does not seem to + // be a way to get L2 cache size from the WebGPU API. + ctx->max_cache = 2097152; + + ctx->instance = wgpuCreateInstance(NULL); + + wgpu_request_adapter_result adapter_result + = wgpu_request_adapter_sync(ctx->instance, NULL); + if (adapter_result.status != WGPURequestAdapterStatus_Success) { + if (adapter_result.message != NULL) { + futhark_panic(-1, "Could not get WebGPU adapter, status: %d\nMessage: %s\n", + adapter_result.status, adapter_result.message); + } else { + futhark_panic(-1, "Could not get WebGPU adapter, status: %d\n", + adapter_result.status); + } + } + ctx->adapter = adapter_result.adapter; + + // We want to request the max limits possible. + // Some limits, like maxStorageBuffersPerShaderStage has a huge impact + // on what programs we can run, so we need to request the maximum possible. 
+ WGPUSupportedLimits supported; + WGPUBool res = wgpuAdapterGetLimits(ctx->adapter, &supported); + if (!res) { + futhark_panic(-1, "Could not get WebGPU adapter limits\n", res); + } + WGPURequiredLimits required_limits; + // If we just zero this memory, stuff crashes in the generated empscripten js. + // For some reason, we have to set it to all 1s? + memset((void*)&required_limits.limits, 0xff, sizeof(required_limits.limits)); + required_limits.limits.maxBindGroups = supported.limits.maxBindGroups; + required_limits.limits.maxBindingsPerBindGroup = supported.limits.maxBindingsPerBindGroup; + required_limits.limits.maxDynamicUniformBuffersPerPipelineLayout = supported.limits.maxDynamicUniformBuffersPerPipelineLayout; + required_limits.limits.maxDynamicStorageBuffersPerPipelineLayout = supported.limits.maxDynamicStorageBuffersPerPipelineLayout; + required_limits.limits.maxStorageBuffersPerShaderStage = supported.limits.maxStorageBuffersPerShaderStage; + required_limits.limits.maxUniformBuffersPerShaderStage = supported.limits.maxUniformBuffersPerShaderStage; + required_limits.limits.maxUniformBufferBindingSize = supported.limits.maxUniformBufferBindingSize; + required_limits.limits.maxStorageBufferBindingSize = supported.limits.maxStorageBufferBindingSize; + required_limits.limits.maxBufferSize = supported.limits.maxBufferSize; + required_limits.limits.maxComputeWorkgroupStorageSize = supported.limits.maxComputeWorkgroupStorageSize; + required_limits.limits.maxComputeInvocationsPerWorkgroup = supported.limits.maxComputeInvocationsPerWorkgroup; + required_limits.limits.maxComputeWorkgroupSizeX = supported.limits.maxComputeWorkgroupSizeX; + required_limits.limits.maxComputeWorkgroupSizeY = supported.limits.maxComputeWorkgroupSizeY; + required_limits.limits.maxComputeWorkgroupSizeZ = supported.limits.maxComputeWorkgroupSizeZ; + required_limits.limits.maxComputeWorkgroupsPerDimension = supported.limits.maxComputeWorkgroupsPerDimension; + + // Require support for 16-bit floats + WGPUFeatureName required_features[] = { WGPUFeatureName_ShaderF16 }; + WGPUDeviceDescriptor device_desc = { + .requiredFeatureCount = 1, + .requiredFeatures = required_features, + .requiredLimits = &required_limits + }; + wgpu_request_device_result device_result + = wgpu_request_device_sync(ctx->instance, ctx->adapter, &device_desc); + if (device_result.status != WGPURequestDeviceStatus_Success) { + if (device_result.message != NULL) { + futhark_panic(-1, "Could not get WebGPU device, status: %d\nMessage: %s\n", + device_result.status, device_result.message); + } else { + futhark_panic(-1, "Could not get WebGPU device, status: %d\n", + device_result.status); + } + } + ctx->device = device_result.device; + wgpuDeviceSetUncapturedErrorCallback(ctx->device, + wgpu_on_uncaptured_error, NULL); + + ctx->queue = wgpuDeviceGetQueue(ctx->device); + + wgpu_size_setup(ctx); + + WGPUBufferDescriptor desc = { + .label = "scalar_readback", + .size = 8, + .usage = WGPUBufferUsage_MapRead | WGPUBufferUsage_CopyDst, + }; + ctx->scalar_readback_buffer = wgpuDeviceCreateBuffer(ctx->device, &desc); + free_list_init(&ctx->gpu_free_list); + + // We implement macros as override constants. 
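// For illustration (hypothetical name): a macro such as a tuning constant ends
// up in the generated WGSL as an override declaration, e.g.
//
//   override some_macro : i32;
//
// and receives its value at pipeline-creation time through a WGPUConstantEntry
// whose .key is the override name and whose .value is the double stored in
// ctx->override_values below.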
+ int64_t *macro_vals; + ctx->num_overrides = gpu_macros(ctx, &ctx->override_names, + ¯o_vals); + ctx->override_values = malloc(ctx->num_overrides * sizeof(double)); + for (int i = 0; i < ctx->num_overrides; i++) { + ctx->override_values[i] = (double) macro_vals[i]; + } + free(macro_vals); + + wgpu_module_setup(ctx, ctx->cfg->program, &ctx->module, "Futhark program"); + + if ((ctx->kernels = init_builtin_kernels(ctx)) == NULL) { + printf("Failed to init builtin kernels\n"); + return 1; + } + + return 0; +} + +void backend_context_teardown(struct futhark_context *ctx) { + if (ctx->kernels != NULL) { + free_builtin_kernels(ctx, ctx->kernels); + } // TODO + free(ctx->override_names); + free(ctx->override_values); + + if (gpu_free_all(ctx) != FUTHARK_SUCCESS) { + futhark_panic(-1, "gpu_free_all failed"); + } + wgpuBufferDestroy(ctx->scalar_readback_buffer); + wgpuDeviceDestroy(ctx->device); + //} + free_list_destroy(&ctx->gpu_free_list); +} + +// Definitions for these are included as part of code generation. +// wgpu_kernel_info contains: +// char *name; +// +// size_t num_scalars; +// size_t scalars_binding; +// size_t scalars_size; +// size_t *scalar_offsets; +// +// size_t num_bindings; // excluding the scalars binding +// uint32_t *binding_indices; +// +// size_t num_overrides; +// char **used_overrides; + +// size_t num_dynamic_block_dims; +// uint32_t *dynamic_block_dim_indices; +// char **dynamic_block_dim_names; +// +// size_t num_shared_mem_overrides; +// char **shared_mem_overrides; +struct wgpu_kernel_info; +static size_t wgpu_num_kernel_infos; +static wgpu_kernel_info wgpu_kernel_infos[]; + +struct wgpu_kernel_info *wgpu_get_kernel_info(const char *name) { + for (int i = 0; i < wgpu_num_kernel_infos; i++) { + if (strcmp(name, wgpu_kernel_infos[i].name) == 0) { + return &wgpu_kernel_infos[i]; + } + } + + return NULL; +} + +// GPU ABSTRACTION LAYER + +// Types. +struct wgpu_kernel { + struct wgpu_kernel_info *info; + + WGPUBuffer scalars_buffer; + WGPUBindGroupLayout bind_group_layout; + WGPUPipelineLayout pipeline_layout; + + // True if we can create a single pipeline in `gpu_create_kernel`. If false, + // need to create a new pipeline for every kernel launch. + bool static_pipeline; + + // ShaderModule for this kernel. Generated from the gpu_program of the kernel + // info. + WGPUShaderModule module; + + // Only set if static_pipeline. + WGPUComputePipeline pipeline; + + // Only set if !static_pipeline. + WGPUConstantEntry *const_entries; + + // How many entries are already set; there is enough space in the allocation + // to additionally set the shared memory and dynamic block dimension entries. + int const_entries_set; +}; +typedef struct wgpu_kernel* gpu_kernel; +typedef WGPUBuffer gpu_mem; + +static int gpu_alloc_actual(struct futhark_context *ctx, + size_t size, gpu_mem *mem_out) { + // Storage buffers bindings must have an effective size that is amultiple of + // 4, so we round up all allocations. 
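// For example, a 10-byte request is padded to ((10 + 4 - 1) / 4) * 4 = 12
// bytes, while a request that is already a multiple of 4, such as 16 bytes,
// is left unchanged.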
+ size = ((size + 4 - 1) / 4) * 4; + WGPUBufferDescriptor desc = { + .size = size, + .usage = WGPUBufferUsage_CopySrc + | WGPUBufferUsage_CopyDst + | WGPUBufferUsage_Storage, + }; + *mem_out = wgpuDeviceCreateBuffer(ctx->device, &desc); + return FUTHARK_SUCCESS; + } + + static int gpu_free_actual(struct futhark_context *ctx, gpu_mem mem) { + (void)ctx; + wgpuBufferDestroy(mem); + return FUTHARK_SUCCESS; +} + +static void gpu_create_kernel(struct futhark_context *ctx, + gpu_kernel *kernel_out, + const char *name) { + if (ctx->debugging) { + fprintf(ctx->log, "Creating kernel %s.\n", name); + } + + struct wgpu_kernel_info *kernel_info = wgpu_get_kernel_info(name); + struct wgpu_kernel *kernel = malloc(sizeof(struct wgpu_kernel)); + kernel->info = kernel_info; + + // If this is a builtin kernel, generate the shader module here + if (kernel_info->gpu_program[0]) { + const char* wgsl = strconcat(kernel_info->gpu_program); + wgpu_module_setup(ctx, wgsl, &kernel->module, name); + free((void*)wgsl); + } + else { + kernel->module = ctx->module; + } + + WGPUBufferDescriptor scalars_desc = { + .label = "kernel scalars", + .size = kernel_info->scalars_size, + .usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst + }; + kernel->scalars_buffer = wgpuDeviceCreateBuffer(ctx->device, &scalars_desc); + + // Create bind group layout. + WGPUBindGroupLayoutEntry *bgl_entries + = calloc(1 + kernel_info->num_bindings, sizeof(WGPUBindGroupLayoutEntry)); + + WGPUBindGroupLayoutEntry *scalar_entry = bgl_entries; + scalar_entry->binding = kernel_info->scalars_binding; + scalar_entry->visibility = WGPUShaderStage_Compute; + WGPUBufferBindingLayout scalar_buffer_layout + = { .type = WGPUBufferBindingType_Uniform }; + scalar_entry->buffer = scalar_buffer_layout; + + for (int i = 0; i < kernel_info->num_bindings; i++) { + WGPUBindGroupLayoutEntry *entry = &bgl_entries[1 + i]; + entry->binding = kernel_info->binding_indices[i]; + entry->visibility = WGPUShaderStage_Compute; + WGPUBufferBindingLayout buffer_layout + = { .type = WGPUBufferBindingType_Storage }; + entry->buffer = buffer_layout; + } + WGPUBindGroupLayoutDescriptor bgl_desc = { + .entryCount = 1 + kernel_info->num_bindings, + .entries = bgl_entries + }; + kernel->bind_group_layout + = wgpuDeviceCreateBindGroupLayout(ctx->device, &bgl_desc); + free(bgl_entries); + + // Create pipeline layout. + WGPUPipelineLayoutDescriptor pl = { + .bindGroupLayoutCount = 1, + .bindGroupLayouts = &kernel->bind_group_layout, + }; + kernel->pipeline_layout = wgpuDeviceCreatePipelineLayout(ctx->device, &pl); + + // Create constants / overrides. + // TODO: We should be able to just set all overrides from the context and + // remove the used_overrides from kernel_info. It only exists because + // Chrome/Dawn currently complains if we set unused constants, see + // https://issues.chromium.org/issues/338624452. 
+ WGPUConstantEntry *const_entries = calloc(kernel_info->num_overrides, + sizeof(WGPUConstantEntry)); + int const_idx = 0; + for (int i = 0; i < ctx->num_overrides; i++) { + for (int j = 0; j < kernel_info->num_overrides; j++) { + if (strcmp(kernel_info->used_overrides[j], ctx->override_names[i]) == 0) { + WGPUConstantEntry *entry = &const_entries[const_idx]; + entry->key = ctx->override_names[i]; + entry->value = ctx->override_values[i]; + const_idx++; + } + } + } + + kernel->static_pipeline = + kernel_info->num_dynamic_block_dims == 0 + && kernel_info->num_shared_mem_overrides == 0; + if (!kernel->static_pipeline) { + kernel->const_entries = const_entries; + kernel->const_entries_set = const_idx; + } + else { + // Create pipeline. + WGPUComputePipelineDescriptor desc = { + .layout = kernel->pipeline_layout, + .compute = { + .module = kernel_info->gpu_program[0] ? kernel->module : ctx->module, + .entryPoint = kernel_info->name, + .constantCount = kernel_info->num_overrides, + .constants = const_entries, + } + }; + + kernel->pipeline = wgpuDeviceCreateComputePipeline(ctx->device, &desc); + + free(const_entries); + } + + *kernel_out = kernel; +} + +static void gpu_free_kernel(struct futhark_context *ctx, + gpu_kernel kernel) { + (void)ctx; + wgpuBufferDestroy(kernel->scalars_buffer); + free(kernel); +} + +static int gpu_scalar_to_device(struct futhark_context *ctx, + const char *provenance, + gpu_mem dst, size_t offset, size_t size, + void *src) { + wgpuQueueWriteBuffer(ctx->queue, dst, offset, src, size); + return FUTHARK_SUCCESS; +} + +static int gpu_scalar_from_device(struct futhark_context *ctx, + const char *provenance, + void *dst, + gpu_mem src, size_t offset, size_t size) { + if (size > 8) { + futhark_panic(-1, "gpu_scalar_from_device with size %zu > 8 is not allowed\n", + size); + } + + size_t copy_size = ((size + 4 - 1) / 4) * 4; + + WGPUCommandEncoder encoder = wgpuDeviceCreateCommandEncoder(ctx->device, NULL); + wgpuCommandEncoderCopyBufferToBuffer(encoder, + src, offset, + ctx->scalar_readback_buffer, 0, + copy_size); + + WGPUCommandBuffer commandBuffer = wgpuCommandEncoderFinish(encoder, NULL); + wgpuQueueSubmit(ctx->queue, 1, &commandBuffer); + + WGPUBufferMapAsyncStatus status = + wgpu_map_buffer_sync(ctx->instance, ctx->scalar_readback_buffer, + WGPUMapMode_Read, 0, copy_size); + if (status != WGPUBufferMapAsyncStatus_Success) { + futhark_panic(-1, "gpu_scalar_from_device: Failed to read scalar from device memory with error %d\n", + status); + } + + const void *mapped = wgpuBufferGetConstMappedRange(ctx->scalar_readback_buffer, + 0, copy_size); + memcpy(dst, mapped, size); + + wgpuBufferUnmap(ctx->scalar_readback_buffer); + return FUTHARK_SUCCESS; +} + +static int memcpy_host2gpu(struct futhark_context *ctx, + const char *provenance, + bool sync, + gpu_mem dst, int64_t dst_offset, + const unsigned char *src, int64_t src_offset, + int64_t nbytes) { + if (nbytes <= 0) { return FUTHARK_SUCCESS; } + + // There is no async copy to device memory at the moment (the spec for + // `writeBuffer` specifies that a copy of the data is always made and there is + // no other good option to use here), so we ignore the sync parameter. + (void)sync; + + // Bound storage buffers and copy operations must have sizes multiple of 4. + // Note that writing more than `nbytes` is safe because we also pad all + // buffers when allocating them, but we can't guarantee that the `src` here + // has enough bytes. 
+ // If this is a copy somewhere into the middle of `dst`, it is also possible + // we overwrite some data here, which would be bad. + int64_t copy_size = ((nbytes + 4 - 1) / 4) * 4; + if (copy_size > nbytes) { + // Potential for an issue if we're not at the end of the destination buffer. + // Find its size to make sure. + uint64_t dst_size = wgpuBufferGetSize(dst); + if (dst_offset + copy_size != dst_size) { + printf("memcpy_host2gpu: Potentially could corrupt data due to padding!\n"); + //futhark_panic(-1, "memcpy_host2gpu: Would corrupt data due to padding!\n"); + } + } + + const unsigned char *buf; + int64_t offset; + if (copy_size > nbytes) { + buf = malloc(copy_size); + offset = 0; + memcpy((unsigned char*)buf, src + src_offset, copy_size); + } + else { + buf = src; + offset = src_offset; + } + + wgpuQueueWriteBuffer(ctx->queue, dst, dst_offset, buf + offset, copy_size); + + if (buf != src) { + free((void*)buf); + } + + return FUTHARK_SUCCESS; +} + +static int memcpy_gpu2host(struct futhark_context *ctx, + const char *provenance, + bool sync, + unsigned char *dst, int64_t dst_offset, + gpu_mem src, int64_t src_offset, + int64_t nbytes) { + if (nbytes <= 0) { return FUTHARK_SUCCESS; } + + // Bound storage buffers and copy operations must have sizes multiple of 4. + // Note that mapping more than `nbytes` is safe because we also pad all + // buffers when allocating them. + int64_t buf_size = ((nbytes + 4 - 1) / 4) * 4; + + WGPUBufferDescriptor desc = { + .label = "tmp_readback", + .size = buf_size, + .usage = WGPUBufferUsage_MapRead | WGPUBufferUsage_CopyDst, + }; + WGPUBuffer readback = wgpuDeviceCreateBuffer(ctx->device, &desc); + + WGPUCommandEncoder encoder = wgpuDeviceCreateCommandEncoder(ctx->device, NULL); + wgpuCommandEncoderCopyBufferToBuffer(encoder, + src, src_offset, + readback, 0, + buf_size); + + WGPUCommandBuffer commandBuffer = wgpuCommandEncoderFinish(encoder, NULL); + wgpuQueueSubmit(ctx->queue, 1, &commandBuffer); + + // TODO: Could we do an actual async mapping here if `sync` is false? + WGPUBufferMapAsyncStatus status = + wgpu_map_buffer_sync(ctx->instance, readback, WGPUMapMode_Read, 0, buf_size); + if (status != WGPUBufferMapAsyncStatus_Success) { + futhark_panic(-1, "memcpy_gpu2host: Failed to copy from device memory with error %d\n", + status); + } + + const void *mapped = wgpuBufferGetConstMappedRange(readback, 0, buf_size); + memcpy(dst + dst_offset, mapped, nbytes); + + wgpuBufferUnmap(readback); + wgpuBufferDestroy(readback); + return FUTHARK_SUCCESS; +} + +static int gpu_memcpy(struct futhark_context *ctx, + const char *provenance, + gpu_mem dst, int64_t dst_offset, + gpu_mem src, int64_t src_offset, + int64_t nbytes) { + // Bound storage buffers and copy operations must have sizes multiple of 4. + // Note that copying more than `nbytes` is memory-safe because we also pad all + // buffers when allocating them. + // It could however corrupt data if the copy is in the middle of the buffer, + // like in host2gpu. + int64_t copy_size = ((nbytes + 4 - 1) / 4) * 4; + if (copy_size > nbytes) { + // Potential for an issue if we're not at the end of the destination buffer. + // Find its size to make sure. 
+ uint64_t dst_size = wgpuBufferGetSize(dst); + if (dst_offset + copy_size != dst_size) { + printf("gpu_memcpy: Potentially could corrupt data due to padding!\n"); + //futhark_panic(-1, "gpu_memcpy: Would corrupt data due to padding!\n"); + } + } + WGPUCommandEncoder encoder = wgpuDeviceCreateCommandEncoder(ctx->device, NULL); + + if (dst == src) { + printf("gpu_memcpy: Cannot memcpy to/from the same buffer. Copying to temporary buffer first.\n"); + // Allocate temporary buffer. + gpu_mem tmp; + gpu_alloc_actual(ctx, copy_size, &tmp); + + // Copy data to temporary buffer and then to the destination. + wgpuCommandEncoderCopyBufferToBuffer(encoder, + src, src_offset, tmp, 0, copy_size); + wgpuCommandEncoderCopyBufferToBuffer(encoder, + tmp, 0, dst, dst_offset, copy_size); + WGPUCommandBuffer commandBuffer = wgpuCommandEncoderFinish(encoder, NULL); + wgpuQueueSubmit(ctx->queue, 1, &commandBuffer); + + // Free the temporary buffer once we have finished copying. + futhark_context_sync(ctx); + gpu_free_actual(ctx, tmp); + } + else { + wgpuCommandEncoderCopyBufferToBuffer(encoder, + src, src_offset, dst, dst_offset, copy_size); + WGPUCommandBuffer commandBuffer = wgpuCommandEncoderFinish(encoder, NULL); + wgpuQueueSubmit(ctx->queue, 1, &commandBuffer); + } + + return FUTHARK_SUCCESS; +} + +static int gpu_launch_kernel(struct futhark_context* ctx, + gpu_kernel kernel, const char *name, + const char *provenance, + const int32_t grid[3], + const int32_t block[3], + unsigned int shared_mem_bytes, + int num_args, + void* args[num_args], + size_t args_sizes[num_args]) { + struct wgpu_kernel_info *kernel_info = kernel->info; + + if (num_args != + kernel_info->num_shared_mem_overrides + + kernel_info->num_scalars + + kernel_info->num_bindings + ) { + futhark_panic(-1, "Kernel %s called with num_args not maching its info\n", + name); + } + + int shared_mem_start = 0; + int scalars_start = shared_mem_start + kernel_info->num_shared_mem_overrides; + int mem_start = scalars_start + kernel_info->num_scalars; + + void *scalars = malloc(kernel_info->scalars_size); + for (int i = 0; i < kernel_info->num_scalars; i++) { + memcpy(scalars + kernel_info->scalar_offsets[i], + args[scalars_start + i], args_sizes[scalars_start + i]); + } + + WGPUBindGroupEntry *bg_entries = calloc(1 + kernel_info->num_bindings, + sizeof(WGPUBindGroupEntry)); + for (int i = 0; i < kernel_info->num_bindings; i++) { + WGPUBindGroupEntry *entry = &bg_entries[1 + i]; + entry->binding = kernel_info->binding_indices[i]; + entry->buffer = (gpu_mem) *((gpu_mem *)args[mem_start + i]); + // In theory setting (offset, size) to (0, 0) should also work and mean + // 'the entire buffer', but as of writing this, Firefox requires + // specifying the size. 
+ entry->offset = 0; + entry->size = wgpuBufferGetSize(entry->buffer); + } + + wgpuQueueWriteBuffer(ctx->queue, kernel->scalars_buffer, 0, + scalars, kernel_info->scalars_size); + + WGPUBindGroupEntry *scalar_entry = bg_entries; + scalar_entry->binding = kernel_info->scalars_binding; + scalar_entry->buffer = kernel->scalars_buffer; + scalar_entry->offset = 0; + scalar_entry->size = kernel_info->scalars_size; + + WGPUBindGroupDescriptor bg_desc = { + .layout = kernel->bind_group_layout, + .entryCount = 1 + kernel_info->num_bindings, + .entries = bg_entries, + }; + WGPUBindGroup bg = wgpuDeviceCreateBindGroup(ctx->device, &bg_desc); + + WGPUComputePipeline pipeline; + if (kernel->static_pipeline) { pipeline = kernel->pipeline; } + else { + int const_entry_idx = kernel->const_entries_set; + for (int i = 0; i < kernel_info->num_dynamic_block_dims; i++) { + WGPUConstantEntry *entry = &kernel->const_entries[const_entry_idx]; + const_entry_idx++; + entry->key = kernel_info->dynamic_block_dim_names[i]; + entry->value = (double) block[kernel_info->dynamic_block_dim_indices[i]]; + } + for (int i = 0; i < kernel_info->num_shared_mem_overrides; i++) { + WGPUConstantEntry *entry = &kernel->const_entries[const_entry_idx]; + const_entry_idx++; + entry->key = kernel_info->shared_mem_overrides[i]; + entry->value = (double) *((int32_t *) args[shared_mem_start + i]); + } + + WGPUComputePipelineDescriptor desc = { + .layout = kernel->pipeline_layout, + .compute = { + .module = kernel->module, + .entryPoint = kernel_info->name, + .constantCount = kernel_info->num_overrides, + .constants = kernel->const_entries, + } + }; + pipeline = wgpuDeviceCreateComputePipeline(ctx->device, &desc); + } + + WGPUCommandEncoder encoder = wgpuDeviceCreateCommandEncoder(ctx->device, NULL); + + WGPUComputePassEncoder pass_encoder + = wgpuCommandEncoderBeginComputePass(encoder, NULL); + wgpuComputePassEncoderSetPipeline(pass_encoder, pipeline); + wgpuComputePassEncoderSetBindGroup(pass_encoder, 0, bg, 0, NULL); + wgpuComputePassEncoderDispatchWorkgroups(pass_encoder, + grid[0], grid[1], grid[2]); + wgpuComputePassEncoderEnd(pass_encoder); + + WGPUCommandBuffer cmd_buffer = wgpuCommandEncoderFinish(encoder, NULL); + wgpuQueueSubmit(ctx->queue, 1, &cmd_buffer); + + free(scalars); + + return FUTHARK_SUCCESS; +} + +// End of backends/webgpu.h. 
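// Illustrative sketch (hypothetical, not generated output): based on the field
// list documented above wgpu_get_kernel_info in webgpu.h, the code generator is
// expected to emit a kernel table along these lines, here for a single made-up
// kernel with two i32 scalars and two storage-buffer bindings:
//
//   static size_t wgpu_num_kernel_infos = 1;
//   static wgpu_kernel_info wgpu_kernel_infos[] = {
//     { .name = "map_kernel_1234",
//       .num_scalars = 2,
//       .scalars_binding = 0,
//       .scalars_size = 8,
//       .scalar_offsets = (size_t[]){0, 4},
//       .num_bindings = 2,
//       .binding_indices = (uint32_t[]){1, 2},
//       .num_overrides = 1,
//       .used_overrides = (char *[]){"some_macro"},
//       .num_dynamic_block_dims = 0,
//       .num_shared_mem_overrides = 0 },
//   };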
diff --git a/rts/c/gpu.h b/rts/c/gpu.h index 60f8f3a2b0..fcc7d29acc 100644 --- a/rts/c/gpu.h +++ b/rts/c/gpu.h @@ -171,6 +171,7 @@ struct builtin_kernels { struct builtin_kernels* init_builtin_kernels(struct futhark_context* ctx) { struct builtin_kernels *kernels = malloc(sizeof(struct builtin_kernels)); + gpu_create_kernel(ctx, &kernels->map_transpose_1b, "map_transpose_1b"); gpu_create_kernel(ctx, &kernels->map_transpose_1b_large, "map_transpose_1b_large"); gpu_create_kernel(ctx, &kernels->map_transpose_1b_low_height, "map_transpose_1b_low_height"); @@ -354,21 +355,25 @@ static int gpu_map_transpose(struct futhark_context* ctx, void* args[11]; size_t args_sizes[11] = { - sizeof(gpu_mem), sizeof(int64_t), - sizeof(gpu_mem), sizeof(int64_t), - sizeof(int32_t), - sizeof(int32_t), - sizeof(int32_t), - sizeof(int32_t), - sizeof(int32_t) + sizeof(int64_t), // dst_offset + sizeof(int64_t), // src_offset + sizeof(int32_t), // num_arrays + sizeof(int32_t), // x_elems + sizeof(int32_t), // y_elems + sizeof(int32_t), // mulx + sizeof(int32_t), // muly + sizeof(int32_t), // repeat_1 + sizeof(int32_t), // repeat_2 + sizeof(gpu_mem), // dst + sizeof(gpu_mem) // src }; - args[0] = &dst; - args[1] = &dst_offset; - args[2] = &src; - args[3] = &src_offset; - args[7] = &mulx; - args[8] = &muly; + args[9] = &dst; + args[0] = &dst_offset; + args[10] = &src; + args[1] = &src_offset; + args[5] = &mulx; + args[6] = &muly; if (dst_offset + k * n * m <= 2147483647L && src_offset + k * n * m <= 2147483647L) { @@ -413,11 +418,11 @@ static int gpu_map_transpose(struct futhark_context* ctx, block[1] = TR_TILE_DIM/TR_ELEMS_PER_THREAD; block[2] = 1; } - args[4] = &k32; - args[5] = &m32; - args[6] = &n32; - args[7] = &mulx32; - args[8] = &muly32; + args[2] = &k32; + args[3] = &m32; + args[4] = &n32; + args[5] = &mulx32; + args[6] = &muly32; } else { if (ctx->logging) { fprintf(ctx->log, "Using large kernel\n"); } kernel = kernel_large; @@ -427,16 +432,16 @@ static int gpu_map_transpose(struct futhark_context* ctx, block[0] = TR_TILE_DIM; block[1] = TR_TILE_DIM/TR_ELEMS_PER_THREAD; block[2] = 1; - args[4] = &k; - args[5] = &m; - args[6] = &n; - args[7] = &mulx; - args[8] = &muly; + args[2] = &k; + args[3] = &m; + args[4] = &n; + args[5] = &mulx; + args[6] = &muly; + args_sizes[2] = sizeof(int64_t); + args_sizes[3] = sizeof(int64_t); args_sizes[4] = sizeof(int64_t); args_sizes[5] = sizeof(int64_t); args_sizes[6] = sizeof(int64_t); - args_sizes[7] = sizeof(int64_t); - args_sizes[8] = sizeof(int64_t); } // Cap the number of thead blocks we launch and figure out how many @@ -445,10 +450,10 @@ static int gpu_map_transpose(struct futhark_context* ctx, int32_t repeat_2 = grid[2] / MAX_TR_THREAD_BLOCKS; grid[1] = repeat_1 > 0 ? MAX_TR_THREAD_BLOCKS : grid[1]; grid[2] = repeat_2 > 0 ? 
MAX_TR_THREAD_BLOCKS : grid[2]; - args[9] = &repeat_1; - args[10] = &repeat_2; - args_sizes[9] = sizeof(repeat_1); - args_sizes[10] = sizeof(repeat_2); + args[7] = &repeat_1; + args[8] = &repeat_2; + args_sizes[7] = sizeof(repeat_1); + args_sizes[8] = sizeof(repeat_2); if (ctx->logging) { fprintf(ctx->log, "\n"); @@ -496,33 +501,34 @@ static int gpu_lmad_copy(struct futhark_context* ctx, const char* provenance, void* args[6+(8*3)]; size_t args_sizes[6+(8*3)]; - args[0] = &dst; - args_sizes[0] = sizeof(gpu_mem); - args[1] = &dst_offset; - args_sizes[1] = sizeof(dst_offset); - args[2] = &src; - args_sizes[2] = sizeof(gpu_mem); - args[3] = &src_offset; - args_sizes[3] = sizeof(src_offset); - args[4] = &n; - args_sizes[4] = sizeof(n); - args[5] = &r; - args_sizes[5] = sizeof(r); + args[28] = &dst; + args_sizes[28] = sizeof(gpu_mem); + args[29] = &src; + args_sizes[29] = sizeof(gpu_mem); + + args[0] = &dst_offset; + args_sizes[0] = sizeof(dst_offset); + args[1] = &src_offset; + args_sizes[1] = sizeof(src_offset); + args[2] = &n; + args_sizes[2] = sizeof(n); + args[3] = &r; + args_sizes[3] = sizeof(r); int64_t zero = 0; for (int i = 0; i < 8; i++) { - args_sizes[6+i*3] = sizeof(int64_t); - args_sizes[6+i*3+1] = sizeof(int64_t); - args_sizes[6+i*3+2] = sizeof(int64_t); + args_sizes[4+i*3] = sizeof(int64_t); + args_sizes[4+i*3+1] = sizeof(int64_t); + args_sizes[4+i*3+2] = sizeof(int64_t); if (i < r) { - args[6+i*3] = &shape[i]; - args[6+i*3+1] = &dst_strides[i]; - args[6+i*3+2] = &src_strides[i]; + args[4+i*3] = &shape[i]; + args[4+i*3+1] = &dst_strides[i]; + args[4+i*3+2] = &src_strides[i]; } else { - args[6+i*3] = &zero; - args[6+i*3+1] = &zero; - args[6+i*3+2] = &zero; + args[4+i*3] = &zero; + args[4+i*3+1] = &zero; + args[4+i*3+2] = &zero; } } const size_t w = 256; // XXX: hardcoded thread block size. 
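// Illustrative note on the argument reordering above (and in the .cl kernel
// signatures below): scalar arguments now come first, in the order they are
// packed into the kernel's scalar uniform buffer, and the gpu_mem buffer
// arguments come last, matching the [shared-memory overrides | scalars |
// buffers] layout that the WebGPU gpu_launch_kernel expects. A call therefore
// takes roughly this shape (names and values hypothetical):
//
//   void *args[] = { &dst_offset, &src_offset, /* ...more scalars... */ &dst, &src };
//   size_t args_sizes[] = { sizeof(int64_t), sizeof(int64_t),
//                           /* ...more scalar sizes... */
//                           sizeof(gpu_mem), sizeof(gpu_mem) };
//   gpu_launch_kernel(ctx, kernel, name, provenance, grid, block,
//                     shared_mem_bytes,
//                     sizeof(args) / sizeof(args[0]), args, args_sizes);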
diff --git a/rts/opencl/copy.cl b/rts/opencl/copy.cl index 38b0334ba4..980f1e5edd 100644 --- a/rts/opencl/copy.cl +++ b/rts/opencl/copy.cl @@ -2,9 +2,7 @@ #define GEN_COPY_KERNEL(NAME, ELEM_TYPE) \ FUTHARK_KERNEL void lmad_copy_##NAME(SHARED_MEM_PARAM \ - __global ELEM_TYPE *dst_mem, \ int64_t dst_offset, \ - __global ELEM_TYPE *src_mem, \ int64_t src_offset, \ int64_t n, \ int r, \ @@ -15,7 +13,9 @@ FUTHARK_KERNEL void lmad_copy_##NAME(SHARED_MEM_PARAM \ int64_t shape4, int64_t dst_stride4, int64_t src_stride4, \ int64_t shape5, int64_t dst_stride5, int64_t src_stride5, \ int64_t shape6, int64_t dst_stride6, int64_t src_stride6, \ - int64_t shape7, int64_t dst_stride7, int64_t src_stride7) { \ + int64_t shape7, int64_t dst_stride7, int64_t src_stride7, \ + __global ELEM_TYPE *dst_mem, \ + __global ELEM_TYPE *src_mem) { \ int64_t gtid = get_global_id(0); \ int64_t remainder = gtid; \ \ diff --git a/rts/opencl/transpose.cl b/rts/opencl/transpose.cl index 84bcd9b542..1f6ad1e445 100644 --- a/rts/opencl/transpose.cl +++ b/rts/opencl/transpose.cl @@ -3,9 +3,7 @@ #define GEN_TRANSPOSE_KERNELS(NAME, ELEM_TYPE) \ FUTHARK_KERNEL_SIZED(TR_BLOCK_DIM*2, TR_TILE_DIM/TR_ELEMS_PER_THREAD, 1)\ void map_transpose_##NAME(SHARED_MEM_PARAM \ - __global ELEM_TYPE *dst_mem, \ int64_t dst_offset, \ - __global ELEM_TYPE *src_mem, \ int64_t src_offset, \ int32_t num_arrays, \ int32_t x_elems, \ @@ -13,7 +11,9 @@ void map_transpose_##NAME(SHARED_MEM_PARAM \ int32_t mulx, \ int32_t muly, \ int32_t repeat_1, \ - int32_t repeat_2) { \ + int32_t repeat_2, \ + __global ELEM_TYPE *dst_mem, \ + __global ELEM_TYPE *src_mem) { \ (void)mulx; (void)muly; \ __local ELEM_TYPE* block = (__local ELEM_TYPE*)shared_mem; \ int tblock_id_0 = get_tblock_id(0); \ @@ -63,17 +63,17 @@ void map_transpose_##NAME(SHARED_MEM_PARAM \ \ FUTHARK_KERNEL_SIZED(TR_BLOCK_DIM, TR_BLOCK_DIM, 1) \ void map_transpose_##NAME##_low_height(SHARED_MEM_PARAM \ - __global ELEM_TYPE *dst_mem, \ - int64_t dst_offset, \ - __global ELEM_TYPE *src_mem, \ - int64_t src_offset, \ - int32_t num_arrays, \ - int32_t x_elems, \ - int32_t y_elems, \ - int32_t mulx, \ - int32_t muly, \ - int32_t repeat_1, \ - int32_t repeat_2) { \ + int64_t dst_offset, \ + int64_t src_offset, \ + int32_t num_arrays, \ + int32_t x_elems, \ + int32_t y_elems, \ + int32_t mulx, \ + int32_t muly, \ + int32_t repeat_1, \ + int32_t repeat_2, \ + __global ELEM_TYPE *dst_mem, \ + __global ELEM_TYPE *src_mem) { \ __local ELEM_TYPE* block = (__local ELEM_TYPE*)shared_mem; \ int tblock_id_0 = get_tblock_id(0); \ int global_id_0 = get_global_id(0); \ @@ -118,9 +118,7 @@ void map_transpose_##NAME##_low_height(SHARED_MEM_PARAM \ \ FUTHARK_KERNEL_SIZED(TR_BLOCK_DIM, TR_BLOCK_DIM, 1) \ void map_transpose_##NAME##_low_width(SHARED_MEM_PARAM \ - __global ELEM_TYPE *dst_mem, \ int64_t dst_offset, \ - __global ELEM_TYPE *src_mem, \ int64_t src_offset, \ int32_t num_arrays, \ int32_t x_elems, \ @@ -128,7 +126,9 @@ void map_transpose_##NAME##_low_width(SHARED_MEM_PARAM \ int32_t mulx, \ int32_t muly, \ int32_t repeat_1, \ - int32_t repeat_2) { \ + int32_t repeat_2, \ + __global ELEM_TYPE *dst_mem, \ + __global ELEM_TYPE *src_mem) { \ __local ELEM_TYPE* block = (__local ELEM_TYPE*)shared_mem; \ int tblock_id_0 = get_tblock_id(0); \ int global_id_0 = get_global_id(0); \ @@ -169,10 +169,8 @@ void map_transpose_##NAME##_low_width(SHARED_MEM_PARAM \ } \ \ FUTHARK_KERNEL_SIZED(TR_BLOCK_DIM*TR_BLOCK_DIM, 1, 1) \ -void map_transpose_##NAME##_small(SHARED_MEM_PARAM \ - __global ELEM_TYPE *dst_mem, \ +void 
map_transpose_##NAME##_small(SHARED_MEM_PARAM \ int64_t dst_offset, \ - __global ELEM_TYPE *src_mem, \ int64_t src_offset, \ int32_t num_arrays, \ int32_t x_elems, \ @@ -180,7 +178,9 @@ void map_transpose_##NAME##_small(SHARED_MEM_PARAM \ int32_t mulx, \ int32_t muly, \ int32_t repeat_1, \ - int32_t repeat_2) { \ + int32_t repeat_2, \ + __global ELEM_TYPE *dst_mem, \ + __global ELEM_TYPE *src_mem) { \ (void)mulx; (void)muly; \ __local ELEM_TYPE* block = (__local ELEM_TYPE*)shared_mem; \ int tblock_id_0 = get_tblock_id(0); \ @@ -211,9 +211,7 @@ void map_transpose_##NAME##_small(SHARED_MEM_PARAM \ \ FUTHARK_KERNEL_SIZED(TR_BLOCK_DIM*2, TR_TILE_DIM/TR_ELEMS_PER_THREAD, 1)\ void map_transpose_##NAME##_large(SHARED_MEM_PARAM \ - __global ELEM_TYPE *dst_mem, \ int64_t dst_offset, \ - __global ELEM_TYPE *src_mem, \ int64_t src_offset, \ int64_t num_arrays, \ int64_t x_elems, \ @@ -221,9 +219,11 @@ void map_transpose_##NAME##_large(SHARED_MEM_PARAM \ int64_t mulx, \ int64_t muly, \ int32_t repeat_1, \ - int32_t repeat_2) { \ + int32_t repeat_2, \ + __global ELEM_TYPE *dst_mem, \ + __global ELEM_TYPE *src_mem) { \ (void)mulx; (void)muly; \ - __local ELEM_TYPE* block = (__local ELEM_TYPE*)shared_mem; \ + __local ELEM_TYPE* block = (__local ELEM_TYPE*)shared_mem; \ int tblock_id_0 = get_tblock_id(0); \ int global_id_0 = get_global_id(0); \ int tblock_id_1 = get_tblock_id(1); \ diff --git a/rts/python/.gitignore b/rts/python/.gitignore new file mode 100644 index 0000000000..c18dd8d83c --- /dev/null +++ b/rts/python/.gitignore @@ -0,0 +1 @@ +__pycache__/ diff --git a/rts/webgpu/server_ws.js b/rts/webgpu/server_ws.js new file mode 100644 index 0000000000..aaedf41517 --- /dev/null +++ b/rts/webgpu/server_ws.js @@ -0,0 +1,255 @@ +// Start of server_ws.js + +// https://stackoverflow.com/a/66046176/3112547 +async function bufferToBase64(buffer) { + const base64url = await new Promise(r => { + const reader = new FileReader(); + reader.onload= () => r(reader.result); + reader.readAsDataURL(new Blob([buffer])); + }); + return base64url.slice(base64url.indexOf(',') + 1); +} + +class BrowserServer { + constructor(fut, port) { + this.fut = fut; + this.vars = {}; + + this.commands = { + 'entry_points': this.cmd_entry_points.bind(this), + 'inputs': this.cmd_inputs.bind(this), + 'outputs': this.cmd_outputs.bind(this), + 'restore': this.cmd_restore.bind(this), + 'store': this.cmd_store.bind(this), + 'free': this.cmd_free.bind(this), + 'call': this.cmd_call.bind(this), + 'clear': this.cmd_clear.bind(this), + 'report': this.cmd_report.bind(this), + 'pause_profiling': this.cmd_pause_profiling.bind(this), + 'unpause_profiling': this.cmd_unpause_profiling.bind(this), + }; + + this.socket = new WebSocket("ws://" + window.location.host + "/ws"); + this.socket.onmessage = async (event) => { + const msg = JSON.parse(event.data); + console.log("WS command:", msg); + + let resp = undefined; + try { + if (!(msg.cmd in this.commands)) { + throw "Unknown command: " + msg.cmd; + } + + const fun = this.commands[msg.cmd]; + const res = await fun(...msg.args); + + await this.fut.context_sync(); + + if (typeof res == "string") { + resp = { status: "ok", text: res }; + } + else { + res['status'] = "ok"; + resp = res; + } + } catch (ex) { + console.log(ex); + resp = { status: "fail", text: ex.toString() }; + } + this.socket.send(JSON.stringify(resp)); + }; + console.log("Created WS client."); + } + + get_entry_point(entry) { + if (entry in this.fut.entry) { + return this.fut.entry[entry]; + } + throw "Unknown entry point: " + entry; + } 
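  // The onmessage handler above fixes the wire protocol spoken with the test
  // driver on the other end of the WebSocket: each request is a JSON object
  // whose "cmd" names one of the commands registered in the constructor and
  // whose "args" array is spread into the handler; every reply carries a
  // "status" of "ok" or "fail" and either a "text" field or, for cmd_store,
  // additional fields such as "data" and "types". Illustrative exchange
  // (example values only, not taken from the patch):
  //
  //   request: {"cmd": "call", "args": ["main", "out0", "in0"]}
  //   reply:   {"status": "ok", "text": "runtime: 1234"}
  //
  //   request: {"cmd": "restore", "args": ["input.data", "in0", "[]i32"]}
  //   reply:   {"status": "ok", "text": ""}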
+ + get_manifest_entry_point(entry) { + if (entry in this.fut.manifest.entry_points) { + return this.fut.manifest.entry_points[entry]; + } + throw "Unknown entry point: " + entry; + } + + get_manifest_type(type) { + if (type in this.fut.manifest.types) { + return this.fut.manifest.types[type]; + } + throw "Unknown type: " + type; + } + + check_var(name) { + if (!(name in this.vars)) { + throw "Unknown variable: " + name; + } + } + + set_var(name, val, typ) { + this.vars[name] = {val: val, typ: typ}; + } + + get_var(name) { + this.check_var(name); + return this.vars[name]; + } + + delete_var(name) { + delete this.vars[name]; + } + + async cmd_entry_points() { + const entries = Object.keys(this.fut.available_entry_points); + return entries.join("\n"); + } + + async cmd_inputs(entry) { + const entry_info = this.get_manifest_entry_point(entry); + const inputs = entry_info.inputs.map(function(arg) { + if (arg.unique) { return "*" + arg.type; } + return arg.type; + }); + return inputs.join("\n"); + } + + async cmd_outputs(entry) { + const entry_info = this.get_manifest_entry_point(entry); + const outputs = entry_info.outputs.map(function(arg) { + if (arg.unique) { return "*" + arg.type; } + return arg.type; + }); + return outputs.join("\n"); + } + + async cmd_restore(file, ...varsAndTypes) { + // Request file from the server. + const data = await fetch(file).then((response) => { + if (!response.ok) { + throw "Failed to fetch file: " + response.statusText; + } + return response.bytes(); + }); + + const reader = new FutharkReader(data); + + for (let i = 0; i < varsAndTypes.length; i += 2) { + const name = varsAndTypes[i]; + const type = varsAndTypes[i+1]; + + const raw_val = reader.read_value(type); + + let val = undefined; + if (type in this.fut.manifest.types) { + const type_info = this.get_manifest_type(type); + futhark_assert(type_info.kind == "array"); + + const [data, shape] = raw_val; + val = this.fut.types[type].from_data(data, ...shape); + } + else { + // Scalar. + val = raw_val; + } + + this.set_var(name, val, type); + } + + return ""; + } + + async cmd_store(...vars) { + let data = ""; + let types = []; + for (const name of vars) { + const {val, typ} = this.get_var(name); + + let to_write = undefined; + if (typ in this.fut.manifest.types) { + const type_info = this.get_manifest_type(typ); + futhark_assert(type_info.kind == "array"); + + const values = await val.values(); + const shape = val.get_shape(); + to_write = [values, shape]; + } + else { + // Scalar. 
+ to_write = val; + } + + const encoded = new FutharkWriter().encode_value(to_write, typ); + data += await bufferToBase64(encoded); + + types.push(typ); + } + + return {'data': data, 'types': types}; + } + + async cmd_free(name) { + const {val, typ} = this.get_var(name); + if (val instanceof FutharkArray) { + val.free(); + } + this.delete_var(name); + return ""; + } + + async cmd_call(entry, ...outsAndIns) { + const entry_info = this.get_manifest_entry_point(entry); + const entry_fun = this.get_entry_point(entry); + const outCount = entry_info.outputs.length; + const outNames = outsAndIns.slice(0, outCount); + const inNames = outsAndIns.slice(outCount, outsAndIns.length); + const ins = inNames.map((n) => this.get_var(n).val); + + const startTime = performance.now(); + const outs = await entry_fun(...ins); + await this.fut.context_sync(); + const endTime = performance.now(); + + for (let i = 0; i < outNames.length; i++) { + this.set_var(outNames[i], outs[i], entry_info.outputs[i].type); + } + + return "runtime: " + Math.round((endTime - startTime) * 1000).toString(); + } + + async cmd_clear() { + await this.fut.clear_caches(); + return ""; + } + + async cmd_report() { + return await this.fut.report(); + } + + async cmd_pause_profiling() { + await this.fut.pause_profiling(); + return ""; + } + + async cmd_unpause_profiling() { + await this.fut.unpause_profiling(); + return ""; + } +} + +async function runServer() { + const m = await Module(); + + // Setting fut into the global scope makes debugging a bit easier, and this is + // not intended to be embedded into anything other than the internal + // `futhark test` / `futhark bench` support anyway. + window.fut = new FutharkModule(); + await fut.init(m); + + window.server = new BrowserServer(fut); +} + +runServer(); + +// End of server_ws.js diff --git a/rts/webgpu/util.js b/rts/webgpu/util.js new file mode 100644 index 0000000000..c9a2d3b335 --- /dev/null +++ b/rts/webgpu/util.js @@ -0,0 +1,37 @@ +// Start of util.js + +function futhark_assert(condition, message) { + if (!condition) { + throw new Error(message || "Assertion failed"); + } +} + +function make_prim_info(tag, size, scalar_type, array_type, create_array, get_heap) { + return { + tag: tag, // tag used in the binary data format + size: size, + scalar_type: scalar_type, + array_type: array_type, + get_heap: get_heap, + create_array: create_array, + }; +} + +const primInfos = { + 'bool': make_prim_info("bool", 1, Boolean, Uint8Array, (h, ...args) => new Uint8Array(h, ...args), (m) => m.HEAPU8), + 'u8': make_prim_info(" u8", 1, Number, Uint8Array, (h, ...args) => new Uint8Array(h, ...args), (m) => m.HEAPU8), + 'i8': make_prim_info(" i8", 1, Number, Int8Array, (h, ...args) => new Int8Array(h, ...args), (m) => m.HEAP8), + 'u16': make_prim_info(" u16", 2, Number, Uint16Array, (h, ...args) => new Uint16Array(h, ...args), (m) => m.HEAPU16), + 'i16': make_prim_info(" i16", 2, Number, Int16Array, (h, ...args) => new Int16Array(h, ...args), (m) => m.HEAP16), + 'u32': make_prim_info(" u32", 4, Number, Uint32Array, (h, ...args) => new Uint32Array(h, ...args), (m) => m.HEAPU32), + 'i32': make_prim_info(" i32", 4, Number, Int32Array, (h, ...args) => new Int32Array(h, ...args), (m) => m.HEAP32), + 'u64': make_prim_info(" u64", 8, BigInt, BigUint64Array, (h, ...args) => new BigUint64Array(h, ...args), (m) => m.HEAPU64), + 'i64': make_prim_info(" i64", 8, BigInt, BigInt64Array, (h, ...args) => new BigInt64Array(h, ...args), (m) => m.HEAP64), + // There is no WASM heap for f16 values since 
Float16Array was only recently (april 2025) made available in browser baselines, + // so we have to do this ugly workaround to reinterpret Uint16 bytes as Float16 when reading from the WASM HEAPU16... + 'f16': make_prim_info(" f16", 2, Number, Float16Array, (h, ...args) => new Float16Array(new Uint16Array(h, ...args).buffer), (m) => m.HEAPU16), + 'f32': make_prim_info(" f32", 4, Number, Float32Array, (h, ...args) => new Float32Array(h, ...args), (m) => m.HEAPF32), + 'f64': make_prim_info( "f64", 8, Number, Float64Array, (h, ...args) => new Float64Array(h, ...args), (m) => m.HEAPF64), +}; + +// End of util.js diff --git a/rts/webgpu/values.js b/rts/webgpu/values.js new file mode 100644 index 0000000000..8611915053 --- /dev/null +++ b/rts/webgpu/values.js @@ -0,0 +1,152 @@ +// Start of values.js + +const futhark_binary_format_version = 2; + +class FutharkReader { + constructor(buf) { + futhark_assert(buf instanceof Uint8Array); + this.buf = buf; + } + + seek(n) { + this.buf = this.buf.subarray(n); + } + + read_byte() { + const b = this.buf[0]; + this.seek(1); + return b; + } + + read_i64() { + const buf = new Uint8Array(this.buf.subarray(0, 8)); + const val = new BigInt64Array(buf.buffer, 0, 1)[0]; + this.seek(8); + return val; + } + + read_value(expected_type = undefined) { + let off = 0; + while (this.is_whitespace(this.buf[off])) off++; + this.seek(off); + + futhark_assert(this.read_byte() == this.byte_val('b'), + "Expected binary input"); + futhark_assert(this.read_byte() == futhark_binary_format_version, + "Can only read binary format version " + futhark_binary_format_version); + + const rank = this.read_byte(); + + const type = String.fromCodePoint(...this.buf.slice(0, 4)).trimStart(); + futhark_assert(type in primInfos, "Unknown type: " + type); + this.seek(4); + + if (expected_type != undefined) { + if (rank == 0 && expected_type != type) { + throw new Error(`Read unexpected type '${rank}d ${type}', expected ${expected_type}`); + } + if (rank > 0) { + let expected_rank = 0; + let rem_type = expected_type; + while (rem_type.startsWith("[]")) { + expected_rank++; + rem_type = rem_type.slice(2); + } + + if (rank != expected_rank || type != rem_type) { + throw new Error(`Read unexpected type '${rank}d ${type}', expected ${expected_type}`); + } + } + } + + let shape = []; + for (let i = 0; i < rank; i++) { + shape.push(this.read_i64()); + } + + if (rank == 0) { + const [val, _] = this.read_array(type, [1n]); + return val[0]; + } + else { + return this.read_array(type, shape); + } + } + + read_array(type, shape) { + const type_info = primInfos[type]; + const flat_len = Number(shape.reduce((a, b) => a * b)); + + const buf = new Uint8Array(this.buf.subarray(0, flat_len * type_info.size)); + const wrapper = new type_info.array_type(buf.buffer, 0, flat_len); + + this.seek(wrapper.byteLength); + return [wrapper, shape]; + } + + is_whitespace(b) { + const whitespace = [' ', '\t', '\n'].map((c) => this.byte_val(c)); + return b in whitespace; + } + + byte_val(c) { return c.charCodeAt(0); } +} + +class FutharkWriter { + encode_value(val, type) { + let elem_type = undefined; + let rank = 0; + let flat_len = 0; + + if (type in primInfos) { + elem_type = type; + rank = 0; + flat_len = 1; + } + else { + elem_type = type.replaceAll("[]", ""); + const [data, shape] = val; + rank = shape.length; + flat_len = Number(shape.reduce((a, b) => a * b)); + } + + const prim_info = primInfos[elem_type]; + const header_size = 3 + 4; + const total_size = header_size + rank * 8 + flat_len * prim_info.size; + + const 
buf = new Uint8Array(total_size); + buf[0] = this.byte_val('b'); + buf[1] = futhark_binary_format_version; + buf[2] = rank; + + const tag = Uint8Array.from(prim_info.tag, c => c.charCodeAt(0)); + buf.set(tag, 3); + + let offset = header_size; + + let data = undefined; + let shape = undefined; + if (rank == 0) { + data = new prim_info.array_type([val]); + shape = []; + } + else { + const [d, s] = val; + data = d; + shape = s; + } + + const dims = new BigInt64Array(shape); + buf.set(new Uint8Array(dims.buffer), offset); + offset += dims.byteLength; + + const bin_data = new Uint8Array(data.buffer, data.byteOffset, data.byteLength); + buf.set(bin_data, offset); + + return buf; + } + + byte_val(c) { return c.charCodeAt(0); } +} + +// End of values.js diff --git a/rts/webgpu/wrappers.js b/rts/webgpu/wrappers.js new file mode 100644 index 0000000000..f847d9b287 --- /dev/null +++ b/rts/webgpu/wrappers.js @@ -0,0 +1,171 @@ +// Start of wrappers.js + +// All of the functionality is in subclasses for the individual array types, +// which are generated into fields of the FutharkModule class +// (e.g. `fut.i32_1d` if `fut` is the FutharkModule instance). +// This is just used as a marker so we can check if some object is an instance +// of any of those generated classes. +class FutharkArray { + constructor(name, arr, shape) { + // Name is only for debugging since the debugger will show + // 'FutharkArrayImpl' as type for all array types. + this.type_name = name; + this.arr = arr; + this.shape = shape; + } +} + +function make_array_class(fut, name) { + const type_info = fut.manifest.types[name]; + const prim_info = primInfos[type_info.elemtype]; + + function wasm_fun(full_name) { + const name = "_" + full_name; + return fut.m[name]; + } + + return class FutharkArrayImpl extends FutharkArray { + constructor(arr, shape) { + super(name, arr, shape); + } + + static from_native(arr) { + const shape_fun = wasm_fun(type_info.ops.shape); + const shape_ptr = shape_fun(fut.ctx, arr); + + const shape = new BigInt64Array( + fut.m.HEAP64.subarray(shape_ptr / 8, shape_ptr / 8 + type_info.rank)); + + return new FutharkArrayImpl(arr, shape); + } + + static from_data(data, ...shape) { + futhark_assert(shape.length == type_info.rank, "wrong number of shape arguments"); + if (typeof(shape[0]) === 'number') { + shape = BigInt64Array.from(shape.map((x) => BigInt(x))); + } + + if (data instanceof Array) { + data = prim_info.create_array(data); + } + futhark_assert(data instanceof prim_info.array_type, + "expected Array or correct TypedArray"); + + const wasm_data = fut.malloc(data.byteLength); + const wasm_view = fut.m.HEAPU8.subarray(wasm_data, wasm_data + data.byteLength); + wasm_view.set(new Uint8Array(data.buffer, data.byteOffset, data.byteLength)); + + const new_fun = wasm_fun(type_info.ops.new); + const arr = new_fun(fut.ctx, wasm_data, ...shape); + + fut.free(wasm_data); + + return new FutharkArrayImpl(arr, shape); + } + + get_shape() { return this.shape; } + + async values() { + futhark_assert(this.arr != undefined, "array already freed"); + + const flat_len = Number(this.shape.reduce((a, b) => a * b)); + const flat_size = flat_len * prim_info.size; + const wasm_data = fut.malloc(flat_size); + + await fut.m.ccall(type_info.ops.values, + 'number', ['number', 'number', 'number'], + [fut.ctx, this.arr, wasm_data], + {async: true}); + + const data = prim_info.create_array( + prim_info.get_heap(fut.m) + .subarray(wasm_data / prim_info.size, + wasm_data / prim_info.size + flat_len) + ); + + fut.free(wasm_data); + + return 
data; + }; + + free() { + const free_fun = wasm_fun(type_info.ops.free); + free_fun(fut.ctx, this.arr); + this.arr = undefined; + } + }; +} + +function make_entry_function(fut, name) { + const entry_info = fut.manifest.entry_points[name]; + + return async function(...inputs) { + futhark_assert(inputs.length == entry_info.inputs.length, + "Unexpected number of input arguments"); + + let real_inputs = []; + + for (let i = 0; i < inputs.length; i++) { + const typ = entry_info.inputs[i].type; + if (typ in primInfos) { + real_inputs.push(primInfos[typ].scalar_type(inputs[i])); + } + else if (typ in fut.manifest.types) { + const type_info = fut.manifest.types[typ]; + if (type_info.kind == "array") { + if (!(inputs[i] instanceof FutharkArray)) { + throw new Error("Entry point array arguments must be FutharkArrays"); + } + real_inputs.push(inputs[i].arr); + } + else { + real_inputs.push(inputs[i]); + } + } + else { + throw new Error("Unknown input type"); + } + } + + let out_ptrs = []; + for (let i = 0; i < entry_info.outputs.length; i++) { + out_ptrs.push(fut.malloc(4)); + } + + await fut.m.ccall(entry_info.cfun, 'number', + Array(1 + out_ptrs.length + real_inputs.length).fill('number'), + [fut.ctx].concat(out_ptrs).concat(real_inputs), {async: true}); + + let outputs = []; + for (let i = 0; i < out_ptrs.length; i++) { + const out_info = entry_info.outputs[i]; + if (out_info.type in primInfos) { + const prim_info = primInfos[out_info.type]; + const val = prim_info.get_heap(fut.m)[out_ptrs[i] / prim_info.size]; + outputs.push(val); + } + else if (out_info.type in fut.manifest.types) { + const type_info = fut.manifest.types[out_info.type]; + if (type_info.kind == "array") { + const array_type = fut.types[out_info.type]; + const val = array_type.from_native(fut.m.HEAP32[out_ptrs[i] / 4]); + outputs.push(val); + } + else { + outputs.push(val); + } + } + else { + throw new Error("Unknown output type"); + } + } + + for (const ptr of out_ptrs) { + fut.free(ptr); + } + + return outputs; + }; +} + +// End of wrappers.js diff --git a/rts/wgsl/atomics.wgsl b/rts/wgsl/atomics.wgsl new file mode 100644 index 0000000000..230080ff59 --- /dev/null +++ b/rts/wgsl/atomics.wgsl @@ -0,0 +1,547 @@ +// Start of atomics.wgsl + +//// atomic read and writes //// + +fn atomic_read_i8_global(p: ptr, read_write>, offset: i32) -> i8 { + let v: i32 = atomicLoad(p); + return norm_i8(v >> bitcast(offset * 8)); +} + +fn atomic_read_i8_shared(p: ptr>) -> i8 { + return norm_i8(atomicLoad(p)); +} + +fn atomic_write_i8_global(p: ptr, read_write>, offset: i32, val: i8) { + let shift_amt = bitcast(offset * 8); + + let mask = 0xff << shift_amt; + let shifted_val = (val << shift_amt) & mask; + + // Note: Despite relaxed semantics, this CAS loop is safe, since we are still + // sequentially consistent since all ops are operating on the same address. 
+ var x = atomicLoad(p); + while (!atomicCompareExchangeWeak(p, x, (x & ~mask) | shifted_val).exchanged) { + x = atomicLoad(p); + } +} + +fn atomic_write_i8_shared(p: ptr>, x: i8) { + atomicStore(p, norm_u8(x)); +} + +fn atomic_read_bool_global(p: ptr, read_write>, offset: i32) -> bool { + return atomic_read_i8_global(p, offset) == 1; +} + +fn atomic_read_bool_shared(p: ptr>) -> bool { + return atomic_read_i8_shared(p) == 1; +} + +fn atomic_write_bool_global(p: ptr, read_write>, offset: i32, val: bool) { + if val { + atomic_write_i8_global(p, offset, 1); + } else { + atomic_write_i8_global(p, offset, 0); + } +} + +fn atomic_write_bool_shared(p: ptr>, x: bool) { + if x { + atomic_write_i8_shared(p, 1); + } else { + atomic_write_i8_shared(p, 0); + } +} + +fn atomic_read_i16_global(p: ptr, read_write>, offset: i32) -> i16 { + let v: i32 = atomicLoad(p); + return norm_i16(v >> bitcast(offset * 16)); +} + +fn atomic_read_i16_shared(p: ptr>) -> i16 { + return norm_i16(atomicLoad(p)); +} + +fn atomic_write_i16_global(p: ptr, read_write>, offset: i32, val: i16) { + let shift_amt = bitcast(offset * 16); + + let mask = 0xffff << shift_amt; + let shifted_val = (val << shift_amt) & mask; + + // Note: Despite relaxed semantics, this CAS loop is safe, since we are still + // sequentially consistent since all ops are operating on the same address. + var x = atomicLoad(p); + while (!atomicCompareExchangeWeak(p, x, (x & ~mask) | shifted_val).exchanged) { + x = atomicLoad(p); + } +} + +fn atomic_write_i16_shared(p: ptr>, x: i16) { + atomicStore(p, norm_u16(x)); +} + +fn atomic_read_i32_global(p: ptr, read_write>, offset: i32) -> i32 { + return atomicLoad(p); +} + +fn atomic_read_i32_shared(p: ptr>) -> i32 { + return atomicLoad(p); +} + +fn atomic_write_i32_global(p: ptr, read_write>, offset: i32, val: i32) { + atomicStore(p, val); +} + +fn atomic_write_i32_shared(p: ptr>, x: i32) { + atomicStore(p, x); +} + +//// f32 atomics //// + +fn atomic_read_f32_global(p: ptr, read_write>, offset: i32) -> f32 { + return bitcast(atomicLoad(p)); +} + +fn atomic_read_f32_shared(p: ptr>) -> f32 { + return bitcast(atomicLoad(p)); +} + +fn atomic_write_f32_global(p: ptr, read_write>, offset: i32, val: f32) { + atomicStore(p, bitcast(val)); +} + +fn atomic_write_f32_shared(p: ptr>, x: f32) { + atomicStore(p, bitcast(x)); +} + +fn atomic_fadd_f32_global(p: ptr, read_write>, offset: i32, x: f32) -> f32 { + var old: f32 = x; + var ret: f32; + + while (old != 0) { + ret = bitcast(atomicExchange(p, 0)) + old; + old = bitcast(atomicExchange(p, bitcast(ret))); + } + + return ret; +} + +fn atomic_fadd_f32_shared(p: ptr>, x: f32) -> f32 { + var old: f32 = x; + var ret: f32; + + while (old != 0) { + ret = bitcast(atomicExchange(p, 0)) + old; + old = bitcast(atomicExchange(p, bitcast(ret))); + } + + return ret; +} + +/// f16 atomics /// + +// TODO: Should just be packing f16's in shared memory as well +fn atomic_read_f16_global(p: ptr, read_write>, offset: i32) -> f16 { + return bitcast>(atomicLoad(p))[offset]; +} + +fn atomic_read_f16_shared(p: ptr>) -> f16 { + return bitcast>(atomicLoad(p))[0]; +} + +fn atomic_write_f16_global(p: ptr, read_write>, offset: i32, val: f16) { + var x = bitcast>(atomicLoad(p)); + var y = x; y[offset] = val; + while (!atomicCompareExchangeWeak(p, bitcast(x), bitcast(y)).exchanged) { + x = bitcast>(atomicLoad(p)); + y = x; y[offset] = val;; + } +} + +fn atomic_write_f16_shared(p: ptr>, x: f16) { + atomicStore(p, bitcast(vec2(x))); +} + +fn atomic_fadd_f16_global(p: ptr, read_write>, offset: i32, x: f16) -> 
f16 { + var val = vec2(0.0); val[offset] = x; + var old = bitcast>(atomicLoad(p)); + while (!atomicCompareExchangeWeak(p, bitcast(old), bitcast(old + val)).exchanged) { + old = bitcast>(atomicLoad(p)); + } + + return old[offset]; +} + +fn atomic_fadd_f16_shared(p: ptr>, x: f16) -> f16 { + var val = vec2(0.0); val[0] = x; + var old = bitcast>(atomicLoad(p)); + while (!atomicCompareExchangeWeak(p, bitcast(old), bitcast(old + val)).exchanged) { + old = bitcast>(atomicLoad(p)); + } + + return old[0]; +} + +//// i32 atomics //// + +fn atomic_add_i32_global(p: ptr, read_write>, offset: i32, x: i32) -> i32 { + return atomicAdd(p, x); +} + +fn atomic_add_i32_shared(p: ptr>, x: i32) -> i32 { + return atomicAdd(p, x); +} + +fn atomic_smax_i32_global(p: ptr, read_write>, offset: i32, x: i32) -> i32 { + return atomicMax(p, x); +} + +fn atomic_smax_i32_shared(p: ptr>, x: i32) -> i32 { + return atomicMax(p, x); +} + +fn atomic_smin_i32_global(p: ptr, read_write>, offset: i32, x: i32) -> i32 { + return atomicMin(p, x); +} + +fn atomic_smin_i32_shared(p: ptr>, x: i32) -> i32 { + return atomicMin(p, x); +} + +fn atomic_umax_i32_global(p: ptr, read_write>, offset: i32, x: i32) -> i32 { + var old = atomicLoad(p); + while (!atomicCompareExchangeWeak(p, old, umax_i32(old, x)).exchanged) { + old = atomicLoad(p); + } + + return old; +} + +fn atomic_umax_i32_shared(p: ptr>, x: i32) -> i32 { + var old = atomicLoad(p); + while (!atomicCompareExchangeWeak(p, old, umax_i32(old, x)).exchanged) { + old = atomicLoad(p); + } + + return old; +} + +fn atomic_umin_i32_global(p: ptr, read_write>, offset: i32, x: i32) -> i32 { + var old = atomicLoad(p); + while (!atomicCompareExchangeWeak(p, old, umin_i32(old, x)).exchanged) { + old = atomicLoad(p); + } + + return old; +} + +fn atomic_umin_i32_shared(p: ptr>, x: i32) -> i32 { + var old = atomicLoad(p); + while (!atomicCompareExchangeWeak(p, old, umin_i32(old, x)).exchanged) { + old = atomicLoad(p); + } + + return old; +} + +fn atomic_and_i32_global(p: ptr, read_write>, offset: i32, x: i32) -> i32 { + return atomicAnd(p, x); +} + +fn atomic_and_i32_shared(p: ptr>, x: i32) -> i32 { + return atomicAnd(p, x); +} + +fn atomic_or_i32_global(p: ptr, read_write>, offset: i32, x: i32) -> i32 { + return atomicOr(p, x); +} + +fn atomic_or_i32_shared(p: ptr>, x: i32) -> i32 { + return atomicOr(p, x); +} + +fn atomic_xor_i32_global(p: ptr, read_write>, offset: i32, x: i32) -> i32 { + return atomicXor(p, x); +} + +fn atomic_xor_i32_shared(p: ptr>, x: i32) -> i32 { + return atomicXor(p, x); +} + +//// i16 atomics //// + +fn atomic_add_i16_global(p: ptr, read_write>, offset: i32, x: i16) -> i16 { + loop { + let old = atomicLoad(p); + let old_i16 = norm_i16(old >> bitcast(offset * 16)); + let val = i32(add_i16(old_i16, norm_i16(x))) << bitcast(offset * 16); + let rest = old & ~(0xffff << bitcast(offset * 16)); + + if (atomicCompareExchangeWeak(p, old, val | rest).exchanged) { + return old_i16; + } + } +} + +fn atomic_add_i16_shared(p: ptr>, x: i16) -> i16 { + var old = atomicLoad(p); + while (!atomicCompareExchangeWeak(p, old, add_i16(norm_i16(old), norm_i16(x))).exchanged) { + old = atomicLoad(p); + } + + return norm_i16(old); +} + +fn atomic_smax_i16_global(p: ptr, read_write>, offset: i32, x: i16) -> i16 { + loop { + let old = atomicLoad(p); + let old_i16 = norm_i16(old >> bitcast(offset * 16)); + let val = norm_u16(max(old_i16, norm_i16(x))) << bitcast(offset * 16); + let rest = old & ~(0xffff << bitcast(offset * 16)); + + if (atomicCompareExchangeWeak(p, old, val | rest).exchanged) { + 
return old_i16; + } + } +} + +fn atomic_smax_i16_shared(p: ptr>, x: i16) -> i16 { + var old = atomicLoad(p); + while (!atomicCompareExchangeWeak(p, old, max(norm_i16(old), norm_i16(x))).exchanged) { + old = atomicLoad(p); + } + + return norm_i16(old); +} + +fn atomic_smin_i16_global(p: ptr, read_write>, offset: i32, x: i16) -> i16 { + loop { + let old = atomicLoad(p); + let old_i16 = norm_i16(old >> bitcast(offset * 16)); + let val = norm_u16(min(old_i16, norm_i16(x))) << bitcast(offset * 16); + let rest = old & ~(0xffff << bitcast(offset * 16)); + + if (atomicCompareExchangeWeak(p, old, val | rest).exchanged) { + return old_i16; + } + } +} + +fn atomic_smin_i16_shared(p: ptr>, x: i16) -> i16 { + var old = atomicLoad(p); + while (!atomicCompareExchangeWeak(p, old, min(norm_i16(old), norm_i16(x))).exchanged) { + old = atomicLoad(p); + } + + return norm_i16(old); +} + +fn atomic_umax_i16_global(p: ptr, read_write>, offset: i32, x: i16) -> i16 { + loop { + let old = atomicLoad(p); + let old_u16 = norm_u16(old >> bitcast(offset * 16)); + let val = norm_u16(umax_i16(old_u16, norm_u16(x))) << bitcast(offset * 16); + let rest = old & ~(0xffff << bitcast(offset * 16)); + + if (atomicCompareExchangeWeak(p, old, val | rest).exchanged) { + return old_u16; + } + } +} + +fn atomic_umax_i16_shared(p: ptr>, x: i16) -> i16 { + var old = atomicLoad(p); + while (!atomicCompareExchangeWeak(p, old, umax_i16(norm_u16(old), norm_u16(x))).exchanged) { + old = atomicLoad(p); + } + + return norm_u16(old); +} + +fn atomic_umin_i16_global(p: ptr, read_write>, offset: i32, x: i16) -> i16 { + loop { + let old = atomicLoad(p); + let old_u16 = norm_u16(old >> bitcast(offset * 16)); + let val = norm_u16(umin_i16(old_u16, norm_u16(x))) << bitcast(offset * 16); + let rest = old & ~(0xffff << bitcast(offset * 16)); + + if (atomicCompareExchangeWeak(p, old, val | rest).exchanged) { + return old_u16; + } + } +} + +fn atomic_umin_i16_shared(p: ptr>, x: i16) -> i16 { + var old = atomicLoad(p); + while (!atomicCompareExchangeWeak(p, old, umin_i16(norm_u16(old), norm_u16(x))).exchanged) { + old = atomicLoad(p); + } + + return norm_u16(old); +} + +fn atomic_and_i16_global(p: ptr, read_write>, offset: i32, x: i16) -> i16 { + let shift = bitcast(offset * 16); + let mask = 0xffff << shift; + return norm_u16(atomicAnd(p, ~mask | (norm_u16(x) << shift)) >> shift); +} + +fn atomic_and_i16_shared(p: ptr>, x: i16) -> i16 { + return norm_u16(atomicAnd(p, ~0xffff | norm_u16(x))); +} + +fn atomic_or_i16_global(p: ptr, read_write>, offset: i32, x: i16) -> i16 { + return norm_u16(atomicOr(p, norm_u16(x) << bitcast(offset * 16)) >> bitcast(offset * 16)); +} + +fn atomic_or_i16_shared(p: ptr>, x: i16) -> i16 { + return norm_u16(atomicOr(p, norm_u16(x))); +} + +fn atomic_xor_i16_global(p: ptr, read_write>, offset: i32, x: i16) -> i16 { + return norm_u16(atomicXor(p, norm_u16(x) << bitcast(offset * 16)) >> bitcast(offset * 16)); +} + +fn atomic_xor_i16_shared(p: ptr>, x: i16) -> i16 { + return norm_u16(atomicXor(p, norm_u16(x))); +} + +//// i8 atomics //// + +fn atomic_add_i8_global(p: ptr, read_write>, offset: i32, x: i8) -> i8 { + loop { + let old = atomicLoad(p); + let old_i8 = norm_i8(old >> bitcast(offset * 8)); + let val = i32(add_i8(old_i8, norm_i8(x))) << bitcast(offset * 8); + let rest = old & ~(0xff << bitcast(offset * 8)); + + if (atomicCompareExchangeWeak(p, old, val | rest).exchanged) { + return old_i8; + } + } +} + +fn atomic_add_i8_shared(p: ptr>, x: i8) -> i8 { + var old = atomicLoad(p); + while (!atomicCompareExchangeWeak(p, old, 
add_i8(norm_i8(old), norm_i8(x))).exchanged) { + old = atomicLoad(p); + } + + return norm_i8(old); +} + +fn atomic_smax_i8_global(p: ptr, read_write>, offset: i32, x: i8) -> i8 { + loop { + let old = atomicLoad(p); + let old_i8 = norm_i8(old >> bitcast(offset * 8)); + let val = norm_u8(max(old_i8, norm_i8(x))) << bitcast(offset * 8); + let rest = old & ~(0xff << bitcast(offset * 8)); + + if (atomicCompareExchangeWeak(p, old, val | rest).exchanged) { + return old_i8; + } + } +} + +fn atomic_smax_i8_shared(p: ptr>, x: i8) -> i8 { + var old = atomicLoad(p); + while (!atomicCompareExchangeWeak(p, old, max(norm_i8(old), norm_i8(x))).exchanged) { + old = atomicLoad(p); + } + + return norm_i8(old); +} + +fn atomic_smin_i8_global(p: ptr, read_write>, offset: i32, x: i8) -> i8 { + loop { + let old = atomicLoad(p); + let old_i8 = norm_i8(old >> bitcast(offset * 8)); + let val = norm_u8(min(old_i8, norm_i8(x))) << bitcast(offset * 8); + let rest = old & ~(0xff << bitcast(offset * 8)); + + if (atomicCompareExchangeWeak(p, old, val | rest).exchanged) { + return old_i8; + } + } +} + +fn atomic_smin_i8_shared(p: ptr>, x: i8) -> i8 { + var old = atomicLoad(p); + while (!atomicCompareExchangeWeak(p, old, min(norm_i8(old), norm_i8(x))).exchanged) { + old = atomicLoad(p); + } + + return norm_i8(old); +} + +fn atomic_umax_i8_global(p: ptr, read_write>, offset: i32, x: i8) -> i8 { + loop { + let old = atomicLoad(p); + let old_u8 = norm_u8(old >> bitcast(offset * 8)); + let val = norm_u8(umax_i8(old_u8, norm_u8(x))) << bitcast(offset * 8); + let rest = old & ~(0xff << bitcast(offset * 8)); + + if (atomicCompareExchangeWeak(p, old, val | rest).exchanged) { + return old_u8; + } + } +} + +fn atomic_umax_i8_shared(p: ptr>, x: i8) -> i8 { + var old = atomicLoad(p); + while (!atomicCompareExchangeWeak(p, old, umax_i8(norm_u8(old), norm_u8(x))).exchanged) { + old = atomicLoad(p); + } + + return norm_u8(old); +} + +fn atomic_umin_i8_global(p: ptr, read_write>, offset: i32, x: i8) -> i8 { + loop { + let old = atomicLoad(p); + let old_u8 = norm_u8(old >> bitcast(offset * 8)); + let val = norm_u8(umin_i8(old_u8, norm_u8(x))) << bitcast(offset * 8); + let rest = old & ~(0xff << bitcast(offset * 8)); + + if (atomicCompareExchangeWeak(p, old, val | rest).exchanged) { + return old_u8; + } + } +} + +fn atomic_umin_i8_shared(p: ptr>, x: i8) -> i8 { + var old = atomicLoad(p); + while (!atomicCompareExchangeWeak(p, old, umin_i8(norm_u8(old), norm_u8(x))).exchanged) { + old = atomicLoad(p); + } + + return norm_u8(old); +} + +fn atomic_and_i8_global(p: ptr, read_write>, offset: i32, x: i8) -> i8 { + let shift = bitcast(offset * 8); + let mask = 0xff << shift; + return norm_u8(atomicAnd(p, ~mask | (norm_u8(x) << shift)) >> shift); +} + +fn atomic_and_i8_shared(p: ptr>, x: i8) -> i8 { + return norm_u8(atomicAnd(p, ~0xff | norm_u8(x))); +} + +fn atomic_or_i8_global(p: ptr, read_write>, offset: i32, x: i8) -> i8 { + return norm_u8(atomicOr(p, norm_u8(x) << bitcast(offset * 8)) >> bitcast(offset * 8)); +} + +fn atomic_or_i8_shared(p: ptr>, x: i8) -> i8 { + return norm_u8(atomicOr(p, norm_u8(x))); +} + +fn atomic_xor_i8_global(p: ptr, read_write>, offset: i32, x: i8) -> i8 { + return norm_u8(atomicXor(p, norm_u8(x) << bitcast(offset * 8)) >> bitcast(offset * 8)); +} + +fn atomic_xor_i8_shared(p: ptr>, x: i8) -> i8 { + return norm_u8(atomicXor(p, norm_u8(x))); +} + +// End of atomics.wgsl diff --git a/rts/wgsl/lmad_copy.wgsl b/rts/wgsl/lmad_copy.wgsl new file mode 100644 index 0000000000..f3f6f19d08 --- /dev/null +++ 
b/rts/wgsl/lmad_copy.wgsl @@ -0,0 +1,83 @@ +override block_size_x: i32 = 1; +override block_size_y: i32 = 1; +override block_size_z: i32 = 1; + +struct CopyParameters { + dst_offset: i64, + src_offset: i64, + n: i64, + r: i32, + shape0: i64, dst_stride0: i64, src_stride0: i64, + shape1: i64, dst_stride1: i64, src_stride1: i64, + shape2: i64, dst_stride2: i64, src_stride2: i64, + shape3: i64, dst_stride3: i64, src_stride3: i64, + shape4: i64, dst_stride4: i64, src_stride4: i64, + shape5: i64, dst_stride5: i64, src_stride5: i64, + shape6: i64, dst_stride6: i64, src_stride6: i64, + shape7: i64, dst_stride7: i64, src_stride7: i64 +} + +@group(0) @binding(0) var args: CopyParameters; +@group(0) @binding(1) var dst_mem: array; +@group(0) @binding(2) var src_mem: array; +@compute @workgroup_size(block_size_x, block_size_y, block_size_z) +fn lmad_copy_NAME(@builtin(global_invocation_id) global_id: vec3) { + var remainder: i32 = i32(global_id.x); + var dst_offset: i32 = args.dst_offset.x; + var src_offset: i32 = args.src_offset.x; + + if (i32(global_id.x) >= args.n.x) { + return; + } + + if (args.r > 0) { + let i: i32 = remainder % args.shape0.x; + dst_offset += i * args.dst_stride0.x; + src_offset += i * args.src_stride0.x; + remainder /= args.shape0.x; + } + if (args.r > 1) { + let i: i32 = remainder % args.shape1.x; + dst_offset += i * args.dst_stride1.x; + src_offset += i * args.src_stride1.x; + remainder /= args.shape1.x; + } + if (args.r > 2) { + let i: i32 = remainder % args.shape2.x; + dst_offset += i * args.dst_stride2.x; + src_offset += i * args.src_stride2.x; + remainder /= args.shape2.x; + } + if (args.r > 3) { + let i: i32 = remainder % args.shape3.x; + dst_offset += i * args.dst_stride3.x; + src_offset += i * args.src_stride3.x; + remainder /= args.shape3.x; + } + if (args.r > 4) { + let i: i32 = remainder % args.shape4.x; + dst_offset += i * args.dst_stride4.x; + src_offset += i * args.src_stride4.x; + remainder /= args.shape4.x; + } + if (args.r > 5) { + let i: i32 = remainder % args.shape5.x; + dst_offset += i * args.dst_stride5.x; + src_offset += i * args.src_stride5.x; + remainder /= args.shape5.x; + } + if (args.r > 6) { + let i: i32 = remainder % args.shape6.x; + dst_offset += i * args.dst_stride6.x; + src_offset += i * args.src_stride6.x; + remainder /= args.shape6.x; + } + if (args.r > 7) { + let i: i32 = remainder % args.shape7.x; + dst_offset += i * args.dst_stride7.x; + src_offset += i * args.src_stride7.x; + remainder /= args.shape7.x; + } + + write_ELEM_TYPE(&dst_mem, dst_offset, read_ELEM_TYPE(&src_mem, src_offset)); +} diff --git a/rts/wgsl/map_transpose.wgsl b/rts/wgsl/map_transpose.wgsl new file mode 100644 index 0000000000..f30fd0be6d --- /dev/null +++ b/rts/wgsl/map_transpose.wgsl @@ -0,0 +1,87 @@ +// Constants used for transpositions. In principle these should be configurable. 
+const TR_BLOCK_DIM: i32 = 16; +const TR_TILE_DIM: i32 = 32; +const TR_ELEMS_PER_THREAD: i32 = 8; + +override block_size_x: i32 = 1; +override block_size_y: i32 = 1; +override block_size_z: i32 = 1; + +struct MapTransposeParameters { + dst_offset: i64, // 0 + src_offset: i64, // 8 + num_arrays: i32, // 16 + x_elems: i32, // 20 + y_elems: i32, // 24 + mulx: i32, // 28 + muly: i32, // 32 + repeat_1: i32, // 36 + repeat_2: i32 // 40 +} + +var shared_memory_ELEM_TYPE: array; + +@group(0) @binding(0) var args: MapTransposeParameters; +@group(0) @binding(1) var dst_mem: array; +@group(0) @binding(2) var src_mem: array; +@compute @workgroup_size(block_size_x, block_size_y, block_size_z) +fn map_transpose_NAME( + @builtin(workgroup_id) group_id: vec3, // tblock_id -> unique id of a group within a dispatch + @builtin(global_invocation_id) global_id: vec3, // global_id -> unique id of a thread within a dispatch + @builtin(local_invocation_id) local_id: vec3, // local_id -> unique id of a thread within a group + @builtin(num_workgroups) num_groups: vec3 +) { + let dst_offset = args.dst_offset[0]; + let src_offset = args.src_offset[0]; + + let tblock_id_0 = i32(group_id[0]); + let global_id_0 = i32(global_id[0]); + var tblock_id_1 = i32(group_id[1]); + var global_id_1 = i32(global_id[1]); + + for (var i1 = 0; i1 <= args.repeat_1; i1++) { + var tblock_id_2 = i32(group_id[2]); + var global_id_2 = i32(global_id[2]); + + for (var i2 = 0; i2 <= args.repeat_2; i2++) { + let our_array_offset = tblock_id_2 * args.x_elems * args.y_elems; + let odata_offset = dst_offset + our_array_offset; + let idata_offset = src_offset + our_array_offset; + var x_index = i32(global_id_0); + var y_index = tblock_id_1 * TR_TILE_DIM + i32(local_id[1]); + + if (x_index < args.x_elems) { + for (var j = 0; j < TR_ELEMS_PER_THREAD; j++) { + var index_i = (y_index + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD)) * args.x_elems + x_index; + if (y_index + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD) < args.y_elems) { + let shared_offset = (i32(local_id[1]) + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD)) * (TR_TILE_DIM+1) + i32(local_id[0]); + let src_val = read_ELEM_TYPE(&src_mem, idata_offset + index_i); + shared_memory_ELEM_TYPE[shared_offset] = src_val; + } + } + } + + workgroupBarrier(); + + x_index = tblock_id_1 * TR_TILE_DIM + i32(local_id[0]); + y_index = tblock_id_0 * TR_TILE_DIM + i32(local_id[1]); + + if (x_index < args.y_elems) { + for (var j = 0; j < TR_ELEMS_PER_THREAD; j++) { + var index_out = (y_index + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD)) * args.y_elems + x_index; + if (y_index + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD) < args.x_elems) { + let shared_offset = i32(local_id[0]) * (TR_TILE_DIM+1) + i32(local_id[1]) + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD); + let src_val = ELEM_TYPE(shared_memory_ELEM_TYPE[shared_offset]); + write_ELEM_TYPE(&dst_mem, odata_offset + index_out, src_val); + } + } + } + + tblock_id_2 += i32(num_groups[2]); + global_id_2 += i32(num_groups[2]) * block_size_z; + } + + tblock_id_1 += i32(num_groups[1]); + global_id_1 += i32(num_groups[1]) * block_size_y; + } +} diff --git a/rts/wgsl/map_transpose_large.wgsl b/rts/wgsl/map_transpose_large.wgsl new file mode 100644 index 0000000000..4bbb7a44ed --- /dev/null +++ b/rts/wgsl/map_transpose_large.wgsl @@ -0,0 +1,90 @@ +// Constants used for transpositions. In principle these should be configurable. 
+const TR_BLOCK_DIM: i32 = 16; +const TR_TILE_DIM: i32 = 32; +const TR_ELEMS_PER_THREAD: i32 = 8; + +override block_size_x: i32 = 1; +override block_size_y: i32 = 1; +override block_size_z: i32 = 1; + +struct MapTransposeParametersLarge { + dst_offset: i64, // 0 + src_offset: i64, // 8 + num_arrays: i64, // 16 + x_elems: i64, // 24 + y_elems: i64, // 32 + mulx: i64, // 40 + muly: i64, // 48 + repeat_1: i32, // 56 + repeat_2: i32 // 60 +} + +var shared_memory_ELEM_TYPE: array; + +@group(0) @binding(0) var args: MapTransposeParametersLarge; +@group(0) @binding(1) var dst_mem: array; +@group(0) @binding(2) var src_mem: array; +@compute @workgroup_size(block_size_x, block_size_y, block_size_z) +fn map_transpose_NAME_large( + @builtin(workgroup_id) group_id: vec3, // tblock_id -> unique id of a group within a dispatch + @builtin(global_invocation_id) global_id: vec3, // global_id -> unique id of a thread within a dispatch + @builtin(local_invocation_id) local_id: vec3, // local_id -> unique id of a thread within a group + @builtin(num_workgroups) num_groups: vec3 +) { + let dst_offset = args.dst_offset[0]; + let src_offset = args.src_offset[0]; + let x_elems = args.x_elems[0]; + let y_elems = args.y_elems[0]; + + let tblock_id_0 = i32(group_id[0]); + let global_id_0 = i32(global_id[0]); + var tblock_id_1 = i32(group_id[1]); + var global_id_1 = i32(global_id[1]); + + for (var i1 = 0; i1 <= args.repeat_1; i1++) { + var tblock_id_2 = i32(group_id[2]); + var global_id_2 = i32(global_id[2]); + + for (var i2 = 0; i2 <= args.repeat_2; i2++) { + let our_array_offset = tblock_id_2 * x_elems * y_elems; + let odata_offset = dst_offset + i32(our_array_offset); + let idata_offset = src_offset + i32(our_array_offset); + + var x_index = i32(global_id_0); + var y_index = tblock_id_1 * TR_TILE_DIM + i32(local_id[1]); + + if (x_index < x_elems) { + for (var j = 0; j < TR_ELEMS_PER_THREAD; j++) { + let index_i = (y_index + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD)) * x_elems + x_index; + if (y_index + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD) < y_elems) { + let shared_offset = (i32(local_id[1]) + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD)) * (TR_TILE_DIM+1) + i32(local_id[0]); + let src_val = read_ELEM_TYPE(&src_mem, idata_offset + index_i); + shared_memory_ELEM_TYPE[shared_offset] = src_val; + } + } + } + + workgroupBarrier(); + + x_index = tblock_id_1 * TR_TILE_DIM + i32(local_id[0]); + y_index = tblock_id_0 * TR_TILE_DIM + i32(local_id[1]); + + if (x_index < y_elems) { + for (var j = 0; j < TR_ELEMS_PER_THREAD; j++) { + let index_out = (y_index + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD)) * y_elems + x_index; + if (y_index + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD) < x_elems) { + let shared_offset = i32(local_id[0]) * (TR_TILE_DIM+1) + i32(local_id[1]) + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD); + let src_val = ELEM_TYPE(shared_memory_ELEM_TYPE[shared_offset]); + write_ELEM_TYPE(&dst_mem, odata_offset + index_out, src_val); + } + } + } + + tblock_id_2 += i32(num_groups[2]); + global_id_2 += i32(num_groups[2]) * block_size_z; + } + + tblock_id_1 += i32(num_groups[1]); + global_id_1 += i32(num_groups[1]) * block_size_y; + } +} \ No newline at end of file diff --git a/rts/wgsl/map_transpose_low_height.wgsl b/rts/wgsl/map_transpose_low_height.wgsl new file mode 100644 index 0000000000..3c6ea720b1 --- /dev/null +++ b/rts/wgsl/map_transpose_low_height.wgsl @@ -0,0 +1,83 @@ +// Constants used for transpositions. In principle these should be configurable. 
+const TR_BLOCK_DIM: i32 = 16; +const TR_TILE_DIM: i32 = 32; +const TR_ELEMS_PER_THREAD: i32 = 8; + +override block_size_x: i32 = 1; +override block_size_y: i32 = 1; +override block_size_z: i32 = 1; + +struct MapTransposeParameters { + dst_offset: i64, // 0 + src_offset: i64, // 8 + num_arrays: i32, // 16 + x_elems: i32, // 20 + y_elems: i32, // 24 + mulx: i32, // 28 + muly: i32, // 32 + repeat_1: i32, // 36 + repeat_2: i32 // 40 +} + +var shared_memory_ELEM_TYPE: array; + +@group(0) @binding(0) var args: MapTransposeParameters; +@group(0) @binding(1) var dst_mem: array; +@group(0) @binding(2) var src_mem: array; +@compute @workgroup_size(block_size_x, block_size_y, block_size_z) +fn map_transpose_NAME_low_height( + @builtin(workgroup_id) group_id: vec3, // tblock_id -> unique id of a group within a dispatch + @builtin(global_invocation_id) global_id: vec3, // global_id -> unique id of a thread within a dispatch + @builtin(local_invocation_id) local_id: vec3, // local_id -> unique id of a thread within a group + @builtin(num_workgroups) num_groups: vec3 +) { + let dst_offset = args.dst_offset[0]; + let src_offset = args.src_offset[0]; + + let tblock_id_0 = i32(group_id[0]); + let global_id_0 = i32(global_id[0]); + var tblock_id_1 = i32(group_id[1]); + var global_id_1 = i32(global_id[1]); + + for (var i1 = 0; i1 <= args.repeat_1; i1++) { + var tblock_id_2 = i32(group_id[2]); + var global_id_2 = i32(global_id[2]); + + for (var i2 = 0; i2 <= args.repeat_2; i2++) { + let our_array_offset = tblock_id_2 * args.x_elems * args.y_elems; + let odata_offset = dst_offset + our_array_offset; + let idata_offset = src_offset + our_array_offset; + var x_index = tblock_id_0 * TR_BLOCK_DIM * args.mulx + + i32(local_id[0]) + + (i32(local_id[1]) % args.mulx) * TR_BLOCK_DIM; + var y_index = tblock_id_1 * TR_BLOCK_DIM + i32(local_id[1]) / args.mulx; + let index_in = y_index * args.x_elems + x_index; + + if (x_index < args.x_elems && y_index < args.y_elems) { + let shared_offset = i32(local_id[1]) * (TR_BLOCK_DIM + 1) + i32(local_id[0]); + let src_val = read_ELEM_TYPE(&src_mem, idata_offset + index_in); + shared_memory_ELEM_TYPE[shared_offset] = src_val; + } + + workgroupBarrier(); + + x_index = tblock_id_1 * TR_BLOCK_DIM + i32(local_id[0]) / args.mulx; + y_index = tblock_id_0 * TR_BLOCK_DIM * args.mulx + + i32(local_id[1]) + + (i32(local_id[0]) % args.mulx) * TR_BLOCK_DIM; + let index_out = y_index * args.y_elems + x_index; + + if (x_index < args.y_elems && y_index < args.x_elems) { + let shared_offset = i32(local_id[0]) * (TR_BLOCK_DIM + 1) + i32(local_id[1]); + let src_val = ELEM_TYPE(shared_memory_ELEM_TYPE[shared_offset]); + write_ELEM_TYPE(&dst_mem, odata_offset + index_out, src_val); + } + + tblock_id_2 += i32(num_groups[2]); + global_id_2 += i32(num_groups[2]) * block_size_z; + } + + tblock_id_1 += i32(num_groups[1]); + global_id_1 += i32(num_groups[1]) * block_size_y; + } +} \ No newline at end of file diff --git a/rts/wgsl/map_transpose_low_width.wgsl b/rts/wgsl/map_transpose_low_width.wgsl new file mode 100644 index 0000000000..ebaa0bb3e9 --- /dev/null +++ b/rts/wgsl/map_transpose_low_width.wgsl @@ -0,0 +1,83 @@ +// Constants used for transpositions. In principle these should be configurable. 
+const TR_BLOCK_DIM: i32 = 16; +const TR_TILE_DIM: i32 = 32; +const TR_ELEMS_PER_THREAD: i32 = 8; + +override block_size_x: i32 = 1; +override block_size_y: i32 = 1; +override block_size_z: i32 = 1; + +struct MapTransposeParameters { + dst_offset: i64, // 0 + src_offset: i64, // 8 + num_arrays: i32, // 16 + x_elems: i32, // 20 + y_elems: i32, // 24 + mulx: i32, // 28 + muly: i32, // 32 + repeat_1: i32, // 36 + repeat_2: i32 // 40 +} + +var shared_memory_ELEM_TYPE: array; + +@group(0) @binding(0) var args: MapTransposeParameters; +@group(0) @binding(1) var dst_mem: array; +@group(0) @binding(2) var src_mem: array; +@compute @workgroup_size(block_size_x, block_size_y, block_size_z) +fn map_transpose_NAME_low_width( + @builtin(workgroup_id) group_id: vec3, // tblock_id -> unique id of a group within a dispatch + @builtin(global_invocation_id) global_id: vec3, // global_id -> unique id of a thread within a dispatch + @builtin(local_invocation_id) local_id: vec3, // local_id -> unique id of a thread within a group + @builtin(num_workgroups) num_groups: vec3 +) { + let dst_offset = args.dst_offset[0]; + let src_offset = args.src_offset[0]; + + let tblock_id_0 = i32(group_id[0]); + let global_id_0 = i32(global_id[0]); + var tblock_id_1 = i32(group_id[1]); + var global_id_1 = i32(global_id[1]); + + for (var i1 = 0; i1 <= args.repeat_1; i1++) { + var tblock_id_2 = i32(group_id[2]); + var global_id_2 = i32(global_id[2]); + + for (var i2 = 0; i2 <= args.repeat_2; i2++) { + let our_array_offset = tblock_id_2 * args.x_elems * args.y_elems; + let odata_offset = dst_offset + our_array_offset; + let idata_offset = src_offset + our_array_offset; + var x_index = tblock_id_0 * TR_BLOCK_DIM + i32(local_id[0]) / args.muly; + var y_index = tblock_id_1 * TR_BLOCK_DIM * args.muly + + i32(local_id[1]) + + (i32(local_id[0]) % args.muly) * TR_BLOCK_DIM; + let index_in = y_index * args.x_elems + x_index; + + if (x_index < args.x_elems && y_index < args.y_elems) { + let shared_offset = i32(local_id[1]) * (TR_BLOCK_DIM + 1) + i32(local_id[0]); + let src_val = read_ELEM_TYPE(&src_mem, idata_offset + index_in); + shared_memory_ELEM_TYPE[shared_offset] = src_val; + } + + workgroupBarrier(); + + x_index = tblock_id_1 * TR_BLOCK_DIM * args.muly + + i32(local_id[0]) + + (i32(local_id[1]) % args.muly) * TR_BLOCK_DIM; + y_index = tblock_id_0 * TR_BLOCK_DIM + i32(local_id[1]) / args.muly; + let index_out = y_index * args.y_elems + x_index; + + if (x_index < args.y_elems && y_index < args.x_elems) { + let shared_offset = i32(local_id[0]) * (TR_BLOCK_DIM + 1) + i32(local_id[1]); + let src_val = ELEM_TYPE(shared_memory_ELEM_TYPE[shared_offset]); + write_ELEM_TYPE(&dst_mem, odata_offset + index_out, src_val); + } + + tblock_id_2 += i32(num_groups[2]); + global_id_2 += i32(num_groups[2]) * block_size_z; + } + + tblock_id_1 += i32(num_groups[1]); + global_id_1 += i32(num_groups[1]) * block_size_y; + } +} \ No newline at end of file diff --git a/rts/wgsl/map_transpose_small.wgsl b/rts/wgsl/map_transpose_small.wgsl new file mode 100644 index 0000000000..6fa497e4df --- /dev/null +++ b/rts/wgsl/map_transpose_small.wgsl @@ -0,0 +1,49 @@ +// Constants used for transpositions. In principle these should be configurable. 
+const TR_BLOCK_DIM: i32 = 16; +const TR_TILE_DIM: i32 = 32; +const TR_ELEMS_PER_THREAD: i32 = 8; + +override block_size_x: i32 = 1; +override block_size_y: i32 = 1; +override block_size_z: i32 = 1; + +struct MapTransposeParameters { + dst_offset: i64, // 0 + src_offset: i64, // 8 + num_arrays: i32, // 16 + x_elems: i32, // 20 + y_elems: i32, // 24 + mulx: i32, // 28 + muly: i32, // 32 + repeat_1: i32, // 36 + repeat_2: i32 // 40 +} + +@group(0) @binding(0) var args: MapTransposeParameters; +@group(0) @binding(1) var dst_mem: array; +@group(0) @binding(2) var src_mem: array; +@compute @workgroup_size(block_size_x, block_size_y, block_size_z) +fn map_transpose_NAME_small( + @builtin(workgroup_id) group_id: vec3, // tblock_id -> unique id of a group within a dispatch + @builtin(global_invocation_id) global_id: vec3, // global_id -> unique id of a thread within a dispatch + @builtin(local_invocation_id) local_id: vec3, // local_id -> unique id of a thread within a group + @builtin(num_workgroups) num_groups: vec3 +) { + let dst_offset = args.dst_offset[0]; + let src_offset = args.src_offset[0]; + + let global_id_0 = i32(global_id[0]); + + let our_array_offset = global_id_0 / (args.y_elems * args.x_elems) * args.y_elems * args.x_elems; + let x_index = (global_id_0 % (args.y_elems * args.x_elems)) / args.y_elems; + let y_index = global_id_0 % args.y_elems; + + let odata_offset = dst_offset + our_array_offset; + let idata_offset = src_offset + our_array_offset; + let index_in = y_index * args.x_elems + x_index; + let index_out = x_index * args.y_elems + y_index; + + if (global_id_0 < args.x_elems * args.y_elems * args.num_arrays) { + write_ELEM_TYPE(&dst_mem, odata_offset + index_out, read_ELEM_TYPE(&src_mem, idata_offset + index_in)); + } +} diff --git a/rts/wgsl/scalar.wgsl b/rts/wgsl/scalar.wgsl new file mode 100644 index 0000000000..481acc4121 --- /dev/null +++ b/rts/wgsl/scalar.wgsl @@ -0,0 +1,200 @@ +fn log_and(a: bool, b: bool) -> bool { return a && b; } +fn log_or(a: bool, b: bool) -> bool { return a || b; } +fn llt(a: bool, b: bool) -> bool { return a == false && b == true; } +fn lle(a: bool, b: bool) -> bool { return a == b || llt(a, b); } + +fn futrts_sqrt32(a: f32) -> f32 { return sqrt(a); } +fn futrts_sqrt16(a: f16) -> f16 { return sqrt(a); } + +fn futrts_rsqrt32(a: f32) -> f32 { return inverseSqrt(a); } +fn futrts_rsqrt16(a: f16) -> f16 { return inverseSqrt(a); } + +//fn futrts_cbrt32(a: f32) -> f32 { return ???; } +//fn futrts_cbrt16(a: f16) -> f16 { return ???; } + +fn futrts_log32(a: f32) -> f32 { return log(a); } +fn futrts_log16(a: f16) -> f16 { return log(a); } + +fn futrts_log10_32(a: f32) -> f32 { return log(a) / log(10); } +fn futrts_log10_16(a: f16) -> f16 { return log(a) / log(10); } + +fn futrts_log1p_32(a: f32) -> f32 { return log(1.0 + a); } +fn futrts_log1p_16(a: f16) -> f16 { return log(1.0 + a); } + +fn futrts_log2_32(a: f32) -> f32 { return log2(a); } +fn futrts_log2_16(a: f16) -> f16 { return log2(a); } + +fn futrts_exp32(a: f32) -> f32 { return exp(a); } +fn futrts_exp16(a: f16) -> f16 { return exp(a); } + +fn futrts_sin32(a: f32) -> f32 { return sin(a); } +fn futrts_sin16(a: f16) -> f16 { return sin(a); } + +fn futrts_sinpi32(a: f32) -> f32 { return sin(a * 3.14159265358979323846); } +fn futrts_sinpi16(a: f16) -> f16 { return sin(a * 3.14159265358979323846); } + +fn futrts_sinh32(a: f32) -> f32 { return sinh(a); } +fn futrts_sinh16(a: f16) -> f16 { return sinh(a); } + +fn futrts_cos32(a: f32) -> f32 { return cos(a); } +fn futrts_cos16(a: f16) -> f16 { return 
cos(a); } + +fn futrts_cospi32(a: f32) -> f32 { return cos(a * 3.14159265358979323846); } +fn futrts_cospi16(a: f16) -> f16 { return cos(a * 3.14159265358979323846); } + +fn futrts_cosh32(a: f32) -> f32 { return cosh(a); } +fn futrts_cosh16(a: f16) -> f16 { return cosh(a); } + +fn futrts_tan32(a: f32) -> f32 { return tan(a); } +fn futrts_tan16(a: f16) -> f16 { return tan(a); } + +fn futrts_tanpi32(a: f32) -> f32 { return tan(a * 3.14159265358979323846); } +fn futrts_tanpi16(a: f16) -> f16 { return tan(a * 3.14159265358979323846); } + +fn futrts_tanh32(a: f32) -> f32 { return tanh(a); } +fn futrts_tanh16(a: f16) -> f16 { return tanh(a); } + +fn futrts_asin32(a: f32) -> f32 { return asin(a); } +fn futrts_asin16(a: f16) -> f16 { return asin(a); } + +fn futrts_asinpi32(a: f32) -> f32 { return asin(a) / 3.14159265358979323846; } +fn futrts_asinpi16(a: f16) -> f16 { return asin(a) / 3.14159265358979323846; } + +fn futrts_asinh32(a: f32) -> f32 { return asinh(a); } +fn futrts_asinh16(a: f16) -> f16 { return asinh(a); } + +fn futrts_acos32(a: f32) -> f32 { return acos(a); } +fn futrts_acos16(a: f16) -> f16 { return acos(a); } + +fn futrts_acospi32(a: f32) -> f32 { return acos(a) / 3.14159265358979323846; } +fn futrts_acospi16(a: f16) -> f16 { return acos(a) / 3.14159265358979323846; } + +fn futrts_acosh32(a: f32) -> f32 { return acosh(a); } +fn futrts_acosh16(a: f16) -> f16 { return acosh(a); } + +fn futrts_atan32(a: f32) -> f32 { return atan(a); } +fn futrts_atan16(a: f16) -> f16 { return atan(a); } + +fn futrts_atanpi32(a: f32) -> f32 { return atan(a) / 3.14159265358979323846; } +fn futrts_atanpi16(a: f16) -> f16 { return atan(a) / 3.14159265358979323846; } + +fn futrts_atanh32(a: f32) -> f32 { return atanh(a); } +fn futrts_atanh16(a: f16) -> f16 { return atanh(a); } + +fn futrts_round_32(a: f32) -> f32 { return round(a); } +fn futrts_round_16(a: f16) -> f16 { return round(a); } + +fn futrts_ceil32(a: f32) -> f32 { return ceil(a); } +fn futrts_ceil16(a: f16) -> f16 { return ceil(a); } + +fn futrts_floor32(a: f32) -> f32 { return floor(a); } +fn futrts_floor16(a: f16) -> f16 { return floor(a); } + +fn futrts_ldexp32(a: f32, b: i32) -> f32 { return ldexp(a, b); } +fn futrts_ldexp16(a: f16, b: i32) -> f16 { return ldexp(a, b); } + +fn futrts_atan2_32(a: f32, b: f32) -> f32 { if (a == 0 && b == 0) { return 0; } return atan2(a, b); } +fn futrts_atan2_16(a: f16, b: f16) -> f16 { if (a == 0 && b == 0) { return 0; } return atan2(a, b); } + +fn futrts_atan2pi_32(a: f32, b: f32) -> f32 { return futrts_atan2_32(a, b) / 3.14159265358979323846; } +fn futrts_atan2pi_16(a: f16, b: f16) -> f16 { return futrts_atan2_16(a, b) / 3.14159265358979323846; } + +fn futrts_to_bits16(a: f16) -> i16 { return bitcast(vec2(a, 0.0)); } +fn futrts_from_bits16(a: i16) -> f16 { return bitcast>(a)[0]; } + +fn futrts_to_bits32(a: f32) -> i32 { return bitcast(a); } +fn futrts_from_bits32(a: i32) -> f32 { return bitcast(a); } + +fn futrts_round32(x: f32) -> f32 { return round(x); } +fn futrts_round16(x: f16) -> f16 { return round(x); } + +fn futrts_lerp32(a: f32, b: f32, t: f32) -> f32 { return mix(a, b, t); } +fn futrts_lerp16(a: f16, b: f16, t: f16) -> f16 { return mix(a, b, t); } + +fn futrts_mad32(a: f32, b: f32, c: f32) -> f32 { return a * b + c; } +fn futrts_mad16(a: f16, b: f16, c: f16) -> f16 { return a * b + c; } + +fn futrts_fma32(a: f32, b: f32, c: f32) -> f32 { return fma(a, b, c); } +fn futrts_fma16(a: f16, b: f16, c: f16) -> f16 { return fma(a, b, c); } + +fn futrts_popc64(a: i64) -> i32 { return countOneBits(a.x) + 
countOneBits(a.y); } +fn futrts_popc32(a: i32) -> i32 { return countOneBits(a); } +fn futrts_popc16(a: i16) -> i32 { return countOneBits(a & 0xffff); } +fn futrts_popc8(a: i8) -> i32 { return countOneBits(a & 0xff); } + +// TODO: mul_hi32 and 64 cannot currently be implemented properly. +fn futrts_umul_hi8(a: i8, b: i8) -> i8 { return norm_u8((norm_u8(a) * norm_u8(b)) >> 8); } +fn futrts_umul_hi16(a: i16, b: i16) -> i16 { return norm_u16((norm_u16(a) * norm_u16(b)) >> 16); } +fn futrts_umul_hi32(a: i32, b: i32) -> i32 { return bitcast(bitcast(a) * bitcast(b)); } +fn futrts_umul_hi64(a: i64, b: i64) -> i64 { return i64(mul_i64(a, b)[1]); } +fn futrts_smul_hi8(a: i8, b: i8) -> i8 { return norm_i8((a * b) >> 8); } +fn futrts_smul_hi16(a: i16, b: i16) -> i16 { return norm_i16((a * b) >> 16); } +fn futrts_smul_hi32(a: i32, b: i32) -> i32 { return a * b; } +fn futrts_smul_hi64(a: i64, b: i64) -> i64 { return i64(mul_i64(a, b)[1]); } + +fn futrts_umad_hi8(a: i8, b: i8, c: i8) -> i8 { return norm_u8(futrts_umul_hi8(a, b) + norm_u8(c)); } +fn futrts_umad_hi16(a: i16, b: i16, c: i16) -> i16 { return norm_u16(futrts_umul_hi16(a, b) + norm_u16(c)); } +fn futrts_umad_hi32(a: i32, b: i32, c: i32) -> i32 { return bitcast(bitcast(futrts_umul_hi32(a, b)) + bitcast(c)); } +fn futrts_umad_hi64(a: i64, b: i64, c: i64) -> i64 { return add_i64(futrts_umul_hi64(a, b), c); } +fn futrts_smad_hi8(a: i8, b: i8, c: i8) -> i8 { return norm_i8(futrts_smul_hi8(a, b) + c); } +fn futrts_smad_hi16(a: i16, b: i16, c: i16) -> i16 { return norm_i16(futrts_smul_hi16(a, b) + c); } +fn futrts_smad_hi32(a: i32, b: i32, c: i32) -> i32 { return futrts_smul_hi32(a, b) + c; } +fn futrts_smad_hi64(a: i64, b: i64, c: i64) -> i64 { return add_i64(futrts_smul_hi64(a, b), c); } + +fn futrts_clzz8(x: i8) -> i32 { return countLeadingZeros(x & 0xff) - 24; } +fn futrts_clzz16(x: i16) -> i32 { return countLeadingZeros(x & 0xffff) - 16; } +fn futrts_clzz32(x: i32) -> i32 { return countLeadingZeros(x); } +fn futrts_clzz64(x: i64) -> i32 { + if (x[1] == 0) { + return countLeadingZeros(x[0]) + 32; + } + else { + return countLeadingZeros(x[1]); + } +} + +fn futrts_ctzz8(x: i8) -> i32 { return min(8, countTrailingZeros(x & 0xff)); } +fn futrts_ctzz16(x: i16) -> i32 { return min(16, countTrailingZeros(x & 0xffff)); } +fn futrts_ctzz32(x: i32) -> i32 { return countTrailingZeros(x); } +fn futrts_ctzz64(x: i64) -> i32 { + if (x[0] == 0) { + return countTrailingZeros(x[1]) + 32; + } + else { + return countTrailingZeros(x[0]); + } +} + +fn futrts_isnan32(x: f32) -> bool { + let exponent = (bitcast(x) >> 23) & 0xFF; + let mantissa = bitcast(x) & 0x7fffff; + + // If the exponent field is all 1s, and mantissa is nonzero, x is a NaN. + return (exponent == 0xff && mantissa != 0); +} + +fn futrts_isinf32(x: f32) -> bool { + let exponent = (bitcast(x) >> 23) & 0xFF; + let mantissa = bitcast(x) & 0x7fffff; + + // If the exponent field is all 1s, and mantissa is zero, x is an infinity. + return (exponent == 0xff && mantissa == 0); +} + +fn futrts_isnan16(x: f16) -> bool { + let bits = bitcast(vec2(x, 0.0)); + let exponent = (bits >> 10) & 0x1F; + let mantissa = bits & 0x3ff; + + // If the exponent field is all 1s, and mantissa is nonzero, x is a NaN. + return (exponent == 0x1f && mantissa != 0); +} + +fn futrts_isinf16(x: f16) -> bool { + let bits = bitcast(vec2(x, 0.0)); + let exponent = (bits >> 10) & 0x1F; + let mantissa = bits & 0x3ff; + + // If the exponent field is all 1s, and mantissa is zero, x is an infinity. 
+ return (exponent == 0x1f && mantissa == 0); +} \ No newline at end of file diff --git a/rts/wgsl/scalar16.wgsl b/rts/wgsl/scalar16.wgsl new file mode 100644 index 0000000000..3d65c5fa85 --- /dev/null +++ b/rts/wgsl/scalar16.wgsl @@ -0,0 +1,185 @@ +// Start of scalar16.wgsl + +alias i16 = i32; + +fn norm_i16(a: i16) -> i32 { + if (a & 0x8000) != 0 { return a | bitcast(0xffff0000u); } + return a & 0x0000ffff; +} + +fn norm_u16(a: i16) -> i32 { + return a & 0x0000ffff; +} + +fn read_i16(buffer: ptr>, read_write>, i: i32) -> i16 { + let elem_idx = i / 2; + let idx_in_elem = i % 2; + + let v = atomicLoad(&((*buffer)[elem_idx])); + return norm_i16(v >> bitcast(idx_in_elem * 16)); +} + +fn write_i16(buffer: ptr>, read_write>, + i: i32, + val: i16 +) { + let elem_idx = i / 2; + let idx_in_elem = i % 2; + + let shift_amt = bitcast(idx_in_elem * 16); + + let mask = 0xffff << shift_amt; + let shifted_val = (val << shift_amt) & mask; + + // First zero out the previous value using the inverted mask. + atomicAnd(&((*buffer)[elem_idx]), ~mask); + // And then write the new value. + atomicOr(&((*buffer)[elem_idx]), shifted_val); +} + +fn add_i16(a: i16, b: i16) -> i16 { + return norm_i16(a + b); +} + +fn neg_i16(a: i16) -> i16 { + return add_i16(~a, 1); +} + +fn sub_i16(a: i16, b: i16) -> i16 { + return add_i16(a, neg_i16(b)); +} + +fn mul_i16(a: i16, b: i16) -> i16 { + return norm_i16(a * b); +} + +fn udiv_i16(a: i16, b: i16) -> i16 { + return norm_i16(udiv_i32(norm_u16(a), norm_u16(b))); +} + +fn udiv_up_i16(a: i16, b: i16) -> i16 { + return norm_i16(udiv_up_i32(norm_u16(a), norm_u16(b))); +} + +fn sdiv_i16(a: i16, b: i16) -> i16 { + return sdiv_i32(a, b); +} + +fn sdiv_up_i16(a: i16, b: i16) -> i16 { + return sdiv_up_i32(a, b); +} + +fn umod_i16(a: i16, b: i16) -> i16 { + return norm_i16(umod_i32(norm_u16(a), norm_u16(b))); +} + +fn smod_i16(a: i16, b: i16) -> i16 { + return smod_i32(a, b); +} + +fn umin_i16(a: i16, b: i16) -> i16 { + return umin_i32(a, b); +} + +fn umax_i16(a: i16, b: i16) -> i16 { + return umax_i32(a, b); +} + +fn shl_i16(a: i16, b: i16) -> i16 { + return a << bitcast(b); +} + +fn lshr_i16(a: i16, b: i16) -> i16 { + return bitcast(bitcast(a) >> bitcast(b)); +} + +fn ashr_i16(a: i16, b: i16) -> i16 { + return a >> bitcast(b); +} + +fn pow_i16(a_p: i16, b: i16) -> i16 { + var a = a_p; + var res: i16 = 1; + var rem: i16 = b; + + while rem != 0 { + if (rem & 1) != 0 { + res = mul_i16(res, a); + } + rem = ashr_i16(rem, 1); + a = mul_i16(a, a); + } + + return res; +} + +fn zext_i16_i32(a: i16) -> i32 { + return a & 0xffff; +} + +// Our representation is already a sign-extended i32. 
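[Editor's note, not part of the patch] The read_i16/write_i16 pair above emulates 16-bit loads and stores on a backend whose storage buffers are addressed in 32-bit words: two halfword lanes share each u32, and a store clears its lane before OR-ing in the new bits. A minimal host-side C sketch of that masking arithmetic (hypothetical helper names; non-atomic, whereas the WGSL version uses atomicAnd/atomicOr so a concurrent write to the other lane of the same word is not lost):

#include <stdint.h>
#include <stdio.h>

static void write_i16_packed(uint32_t *buf, int i, uint32_t val) {
  uint32_t shift = (uint32_t)(i % 2) * 16;
  uint32_t mask = 0xffffu << shift;
  /* Clear the target lane, then OR in the new halfword. */
  buf[i / 2] = (buf[i / 2] & ~mask) | ((val << shift) & mask);
}

static int32_t read_i16_packed(const uint32_t *buf, int i) {
  uint32_t half = (buf[i / 2] >> ((i % 2) * 16)) & 0xffffu;
  /* Sign-extend to the canonical i32 representation, like norm_i16 above. */
  return (half & 0x8000u) ? (int32_t)half - 0x10000 : (int32_t)half;
}

int main(void) {
  uint32_t buf[1] = {0};
  write_i16_packed(buf, 0, 0x1234);
  write_i16_packed(buf, 1, 0xfffe); /* -2 as an unsigned halfword */
  printf("%08x %d %d\n", (unsigned)buf[0],
         (int)read_i16_packed(buf, 0), (int)read_i16_packed(buf, 1));
  /* prints: fffe1234 4660 -2 */
  return 0;
}

The same scheme is used with four byte lanes per word in scalar8.wgsl below.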
+fn sext_i16_i32(a: i16) -> i32 { + return a; +} + +fn zext_i16_i64(a: i16) -> i64 { + return i64(zext_i16_i32(a), 0); +} + +fn sext_i16_i64(a: i16) -> i64 { + return sext_i32_i64(a); +} + +fn trunc_i32_i16(a: i32) -> i16 { + return a & 0xffff; +} + +fn trunc_i64_i16(a: i64) -> i16 { + return trunc_i32_i16(a.x); +} + +fn f16_to_u16(a: f16) -> i16 { + return norm_u16(i32(u32(a))); +} + +fn f32_to_u16(a: f32) -> i16 { + return norm_u16(i32(u32(a))); +} + +fn f16_to_i16(a: f16) -> i16 { + return norm_i16(i32(a)); +} + +fn f32_to_i16(a: f32) -> i16 { + return norm_i16(i32(a)); +} + +fn u16_to_f16(a: i16) -> f16 { + return f16(bitcast(a)); +} + +fn u16_to_f32(a: i16) -> f32 { + return f32(bitcast(a)); +} + +fn bool_to_i16(a: bool) -> i16 { + if a { return 1; } else { return 0; } +} + +fn f16_inf_helper() -> u32 { return 0x7c00u; } +fn f16_neg_inf_helper() -> u32 { return 0xfc00u; } +fn f16_nan_helper() -> u32 { return 0xffffu; } + +fn f16_inf() -> f16 { + return bitcast>(f16_inf_helper())[0]; +} + +fn f16_neg_inf() -> f16 { + return bitcast>(f16_neg_inf_helper())[0]; +} + +fn f16_nan() -> f16 { + return bitcast>(f16_nan_helper())[0]; +} + +// End of scalar16.wgsl diff --git a/rts/wgsl/scalar32.wgsl b/rts/wgsl/scalar32.wgsl new file mode 100644 index 0000000000..5dcd66c321 --- /dev/null +++ b/rts/wgsl/scalar32.wgsl @@ -0,0 +1,127 @@ +// Start of scalar32.wgsl + +fn read_i32(buffer: ptr, read_write>, i: i32) -> i32 { + return (*buffer)[i]; +} + +fn write_i32(buffer: ptr, read_write>, i: i32, val: i32) { + (*buffer)[i] = val; +} + +fn neg_i32(a: i32) -> i32 { + return -a; +} + +fn udiv_i32(a: i32, b: i32) -> i32 { + return bitcast(bitcast(a) / bitcast(b)); +} + +fn udiv_up_i32(a_s: i32, b_s: i32) -> i32 { + let a = bitcast(a_s); + let b = bitcast(b_s); + return bitcast((a + b - 1) / b); +} + +fn sdiv_i32(a: i32, b: i32) -> i32 { + let q = a / b; + let r = a % b; + if (r != 0 && ((r < 0) != (b < 0))) { return q - 1; } + return q; +} + +fn sdiv_up_i32(a: i32, b: i32) -> i32 { + return sdiv_i32(a + b - 1, b); +} + +fn umod_i32(a: i32, b: i32) -> i32 { + return bitcast(bitcast(a) % bitcast(b)); +} + +fn smod_i32(a: i32, b: i32) -> i32 { + let r = a % b; + if (r == 0 || (a > 0 && b > 0) || (a < 0 && b < 0)) { return r; } + return r + b; +} + +fn umin_i32(a: i32, b: i32) -> i32 { + return bitcast(min(bitcast(a), bitcast(b))); +} + +fn umax_i32(a: i32, b: i32) -> i32 { + return bitcast(max(bitcast(a), bitcast(b))); +} + +fn shl_i32(a: i32, b: i32) -> i32 { + return a << bitcast(b); +} + +fn lshr_i32(a: i32, b: i32) -> i32 { + return bitcast(bitcast(a) >> bitcast(b)); +} + +fn ashr_i32(a: i32, b: i32) -> i32 { + return a >> bitcast(b); +} + +fn pow_i32(a_p: i32, b: i32) -> i32 { + var a = a_p; + var res: i32 = 1; + var rem: i32 = b; + + while rem != 0 { + if (rem & 1) != 0 { + res = res * a; + } + rem = rem >> 1; + a = a * a; + } + + return res; +} + +fn ult_i32(a: i32, b: i32) -> bool { + return bitcast(a) < bitcast(b); +} + +fn ule_i32(a: i32, b: i32) -> bool { + return bitcast(a) <= bitcast(b); +} + +fn usignum_i32(a: i32) -> i32 { + if a == 0 { return 0; } + return 1; +} + +fn f16_to_u32(a: f16) -> i32 { + return bitcast(u32(a)); +} + +fn f32_to_u32(a: f32) -> i32 { + return bitcast(u32(a)); +} + +fn u32_to_f16(a: i32) -> f16 { + return f16(bitcast(a)); +} + +fn u32_to_f32(a: i32) -> f32 { + return f32(bitcast(a)); +} + +fn f32_inf_helper() -> u32 { return 0x7f800000u; } +fn f32_neg_inf_helper() -> u32 { return 0xff800000u; } +fn f32_nan_helper() -> u32 { return 0xffffffffu; } + +fn f32_inf() -> f32 { 
+ return bitcast(f32_inf_helper()); +} + +fn f32_neg_inf() -> f32 { + return bitcast(f32_neg_inf_helper()); +} + +fn f32_nan() -> f32 { + return bitcast(f32_nan_helper()); +} + +// End of scalar32.wgsl diff --git a/rts/wgsl/scalar64.wgsl b/rts/wgsl/scalar64.wgsl new file mode 100644 index 0000000000..cdbbad8c91 --- /dev/null +++ b/rts/wgsl/scalar64.wgsl @@ -0,0 +1,255 @@ +// Start of scalar64.wgsl + +alias i64 = vec2; // (low, high) + +const zero_i64: i64 = i64(0, 0); +const one_i64: i64 = i64(1, 0); + +fn read_i64(buffer: ptr, read_write>, i: i32) -> i64 { + return (*buffer)[i]; +} + +fn write_i64(buffer: ptr, read_write>, i: i32, val: i64) { + (*buffer)[i] = val; +} + +fn add_i64(a: i64, b: i64) -> i64 { + // return bitcast(add_u64(bitcast(a), bitcast(b))); + var r = a + b; + if (bitcast(r.x) < bitcast(a.x)) { r.y += 1; } + return r; +} + +fn neg_i64(a: i64) -> i64 { + return add_i64(~a, one_i64); +} + +fn sub_i64(a: i64, b: i64) -> i64 { + return add_i64(a, neg_i64(b)); +} + +fn mul_u32_full(a32: u32, b32: u32) -> i64 { + let a = vec2(a32 & 0xFFFF, a32 >> 16); + let b = vec2(b32 & 0xFFFF, b32 >> 16); + let ll = a.x * b.x; var hh = a.y * b.y; + let lh = a.x * b.y; var hl = a.y * b.x; + let mid = hl + (ll >> 16) + (lh & 0xFFFF); + return bitcast(vec2( + (mid << 16) | (ll & 0xFFFF), + hh + (mid >> 16) + (lh >> 16) + )); +} + +fn mul_i64(a: i64, b: i64) -> i64 { + return add_i64( + mul_u32_full(bitcast(a.x), bitcast(b.x)), + i64(0, a.x * b.y + b.x * a.y) + ); +} + +// TODO: i64 division-related operations are not yet implemented properly. +// As a stopgap to at least deal with 64-bit size calculations (where the values +// are always small enough for this to be fine on WebGPU), they just truncate to +// i32 values. +fn udiv_i64(a: i64, b: i64) -> i64 { + return i64(udiv_i32(a.x, b.y), 0); +} + +fn udiv_up_i64(a: i64, b: i64) -> i64 { + return i64(udiv_up_i32(a.x, b.x), 0); +} + +fn sdiv_i64(a: i64, b: i64) -> i64 { + return sext_i32_i64(sdiv_i32(a.x, b.x)); +} + +fn sdiv_up_i64(a: i64, b: i64) -> i64 { + return sext_i32_i64(sdiv_up_i32(a.x, b.x)); +} + +fn umod_i64(a: i64, b: i64) -> i64 { + return i64(umod_i32(a.x, b.x), 0); +} + +fn smod_i64(a: i64, b: i64) -> i64 { + return sext_i32_i64(smod_i32(a.x, b.x)); +} + +fn squot_i64(a: i64, b: i64) -> i64 { + return sext_i32_i64(a.x / b.x); +} + +fn srem_i64(a: i64, b: i64) -> i64 { + return sext_i32_i64(a.x % b.x); +} + +fn smin_i64(a: i64, b: i64) -> i64 { + if slt_i64(a, b) { return a; } + return b; +} + +fn umin_i64(a: i64, b: i64) -> i64 { + if ult_i64(a, b) { return a; } + return b; +} + +fn smax_i64(a: i64, b: i64) -> i64 { + if slt_i64(a, b) { return b; } + return a; +} + +fn umax_i64(a: i64, b: i64) -> i64 { + if ult_i64(a, b) { return b; } + return a; +} + +fn shl_i64(a: i64, b_full: i64) -> i64 { + // Shifting by more than 64 and by negative amounts is undefined, so we can + // assume b.y is 0 and b.x >= 0. + let b: u32 = bitcast(b_full.x); + + if b == 0 { return a; } + if b >= 32 { return i64(0, a.x << (b - 32)); } + + let shifted_over = bitcast(bitcast(a.x) >> (32 - b)); + return i64(a.x << b, (a.y << b) | shifted_over); +} + +fn lshr_i64(a: i64, b_full: i64) -> i64 { + // Shifting by more than 64 and by negative amounts is undefined, so we can + // assume b.y is 0 and b.x >= 0. 
+ let b: i32 = b_full.x; + + if b == 0 { return a; } + if b >= 32 { return i64(lshr_i32(a.y, b - 32), 0); } + + let shifted_over = a.y << bitcast(32 - b); + return i64(lshr_i32(a.x, b) | shifted_over, lshr_i32(a.y, b)); +} + +fn ashr_i64(a: i64, b_full: i64) -> i64 { + // Shifting by more than 64 and by negative amounts is undefined, so we can + // assume b.y is 0 and b.x >= 0. + let b: u32 = bitcast(b_full.x); + + if b == 0 { return a; } + if b >= 32 { + var high: i32; + if a.y < 0 { high = -1; } else { high = 0; } + return i64(a.y >> (b - 32), high); + } + + let shifted_over = a.y << (32 - b); + return i64(lshr_i32(a.x, bitcast(b)) | shifted_over, a.y >> b); +} + +fn pow_i64(a_p: i64, b: i64) -> i64 { + var a = a_p; + var res: i64 = one_i64; + var rem: i64 = b; + + while !eq_i64(rem, zero_i64) { + if !eq_i64(rem & one_i64, zero_i64) { + res = mul_i64(res, a); + } + rem = ashr_i64(rem, one_i64); + a = mul_i64(a, a); + } + + return res; +} + +fn eq_i64(a: i64, b: i64) -> bool { + return all(a == b); +} + +fn ult_i64(a_s: i64, b_s: i64) -> bool { + let a = bitcast>(a_s); + let b = bitcast>(b_s); + return a.y < b.y || (a.y == b.y && a.x < b.x); +} + +fn ule_i64(a_s: i64, b_s: i64) -> bool { + let a = bitcast>(a_s); + let b = bitcast>(b_s); + return a.y < b.y || (a.y == b.y && a.x <= b.x); +} + +fn slt_i64(a: i64, b: i64) -> bool { + return a.y < b.y || (a.y == b.y && a.x < b.x); +} + +fn sle_i64(a: i64, b: i64) -> bool { + return a.y < b.y || (a.y == b.y && a.x <= b.x); +} + +fn abs_i64(a: i64) -> i64 { + if slt_i64(a, zero_i64) { return neg_i64(a); } + return a; +} + +fn ssignum_i64(a: i64) -> i64 { + if slt_i64(a, zero_i64) { return i64(-1, -1); } + if all(a == zero_i64) { return i64(0, 0); } + return i64(1, 0); +} + +fn usignum_i64(a: i64) -> i64 { + if all(a == zero_i64) { return i64(0, 0); } + return i64(1, 0); +} + +fn zext_i32_i64(a: i32) -> i64 { + return i64(a, 0); +} + +fn sext_i32_i64(a: i32) -> i64 { + if (a < 0) { return i64(a, -1); } + return i64(a, 0); +} + +fn trunc_i64_i32(a: i64) -> i32 { + return a.x; +} + +fn bool_to_i64(a: bool) -> i64 { + if a { return one_i64; } else { return zero_i64; } +} + +fn i64_to_bool(a: i64) -> bool { + if eq_i64(a, zero_i64) { return false; } else { return true; } +} + +// TODO: This is not accurate to a real i64->f16 conversion, but hopefully good +// enough for now. +fn i64_to_f16(a: i64) -> f16 { + if (a.y == -1) { + if (a.x > 0) { return -f16(a.x); } + else { return f16(a.x); } + } + // Just ignoring the high bits. They will be out of range of f16 anyway, and + // since WGSL does not even spec that infinity works as expected, I'm not sure + // what else to do here. + return f16(a.x); +} + +// TODO: This is not accurate to a real i64->f32 conversion, but hopefully good +// enough for now. +fn i64_to_f32(a: i64) -> f32 { + if (a.y == -1) { + if (a.x > 0) { return -f32(a.x); } + else { return f32(a.x); } + } + return f32(bitcast(a.x)) + f32(a.y) * 2e32f; +} + +fn u64_to_f16(a: i64) -> f16 { + // See i64_to_f16 regarding the high bits. 
+ return f16(bitcast(a.x)); +} + +fn u64_to_f32(a: i64) -> f32 { + return f32(bitcast(a.x)) + f32(bitcast(a.y)) * 2e32f; +} + +// End of scalar64.wgsl diff --git a/rts/wgsl/scalar8.wgsl b/rts/wgsl/scalar8.wgsl new file mode 100644 index 0000000000..305e539425 --- /dev/null +++ b/rts/wgsl/scalar8.wgsl @@ -0,0 +1,195 @@ +// Start of scalar8.wgsl + +alias i8 = i32; + +fn norm_i8(a: i8) -> i32 { + if (a & 0x80) != 0 { return a | bitcast(0xffffff00u); } + return a & 0x000000ff; +} + +fn norm_u8(a: i8) -> i32 { + return a & 0x000000ff; +} + +fn read_i8(buffer: ptr>, read_write>, i: i32) -> i8 { + let elem_idx = i / 4; + let idx_in_elem = i % 4; + + let v = atomicLoad(&((*buffer)[elem_idx])); + return norm_i8(v >> bitcast(idx_in_elem * 8)); +} + +fn read_bool(buffer: ptr>, read_write>, + i: i32 +) -> bool { + return read_i8(buffer, i) != 0; +} + +fn write_i8(buffer: ptr>, read_write>, + i: i32, + val: i8 +) { + let elem_idx = i / 4; + let idx_in_elem = i % 4; + + let shift_amt = bitcast(idx_in_elem * 8); + + let mask = 0xff << shift_amt; + let shifted_val = (val << shift_amt) & mask; + + // First zero out the previous value using the inverted mask. + atomicAnd(&((*buffer)[elem_idx]), ~mask); + // And then write the new value. + atomicOr(&((*buffer)[elem_idx]), shifted_val); +} + +fn write_bool(buffer: ptr>, read_write>, + i: i32, + val: bool +) { + if val { write_i8(buffer, i, 1); } + else { write_i8(buffer, i, 0); } +} + +fn add_i8(a: i8, b: i8) -> i8 { + return norm_i8(a + b); +} + +fn neg_i8(a: i8) -> i8 { + return add_i8(~a, 1); +} + +fn sub_i8(a: i8, b: i8) -> i8 { + return add_i8(a, neg_i8(b)); +} + +fn mul_i8(a: i8, b: i8) -> i8 { + return norm_i8(a * b); +} + +fn udiv_i8(a: i8, b: i8) -> i8 { + return norm_i8(udiv_i32(norm_u8(a), norm_u8(b))); +} + +fn udiv_up_i8(a: i8, b: i8) -> i8 { + return norm_i8(udiv_up_i32(norm_u8(a), norm_u8(b))); +} + +fn sdiv_i8(a: i8, b: i8) -> i8 { + return sdiv_i32(a, b); +} + +fn sdiv_up_i8(a: i8, b: i8) -> i8 { + return sdiv_up_i32(a, b); +} + +fn umod_i8(a: i8, b: i8) -> i8 { + return norm_i8(umod_i32(norm_u8(a), norm_u8(b))); +} + +fn smod_i8(a: i8, b: i8) -> i8 { + return smod_i32(a, b); +} + +fn umin_i8(a: i8, b: i8) -> i8 { + return umin_i32(a, b); +} + +fn umax_i8(a: i8, b: i8) -> i8 { + return umax_i32(a, b); +} + +fn shl_i8(a: i8, b: i8) -> i8 { + return a << bitcast(b); +} + +fn lshr_i8(a: i8, b: i8) -> i8 { + return bitcast(bitcast(a) >> bitcast(b)); +} + +fn ashr_i8(a: i8, b: i8) -> i8 { + return a >> bitcast(b); +} + +fn pow_i8(a_p: i8, b: i8) -> i8 { + var a = a_p; + var res: i8 = 1; + var rem: i8 = b; + + while rem != 0 { + if (rem & 1) != 0 { + res = mul_i8(res, a); + } + rem = ashr_i8(rem, 1); + a = mul_i8(a, a); + } + + return res; +} + +fn zext_i8_i16(a: i8) -> i16 { + return a & 0xff; +} + +fn sext_i8_i16(a: i8) -> i16 { + return a; +} + +fn zext_i8_i32(a: i8) -> i32 { + return a & 0xff; +} + +fn sext_i8_i32(a: i8) -> i32 { + // The representation is already a sign-extended i32. 
+ return a; +} + +fn zext_i8_i64(a: i8) -> i64 { + return i64(zext_i8_i32(a), 0); +} + +fn sext_i8_i64(a: i8) -> i64 { + return sext_i32_i64(a); +} + +fn trunc_i16_i8(a: i32) -> i8 { + return a & 0xff; +} + +fn trunc_i32_i8(a: i32) -> i8 { + return a & 0xff; +} + +fn trunc_i64_i8(a: i64) -> i8 { + return trunc_i32_i8(a.x); +} + +fn f16_to_u8(a: f16) -> i8 { + return norm_u8(i32(u32(a))); +} + +fn f32_to_u8(a: f32) -> i8 { + return norm_u8(i32(u32(a))); +} + +fn f16_to_i8(a: f16) -> i8 { + return norm_i8(i32(a)); +} + +fn f32_to_i8(a: f32) -> i8 { + return norm_i8(i32(a)); +} + +fn u8_to_f16(a: i8) -> f16 { + return f16(bitcast(a)); +} + +fn u8_to_f32(a: i8) -> f32 { + return f32(bitcast(a)); +} + +fn bool_to_i8(a: bool) -> i8 { + if a { return 1; } else { return 0; } +} + +// End of scalar8.wgsl diff --git a/shell.nix b/shell.nix index a57b73780c..d8ad9c71a8 100644 --- a/shell.nix +++ b/shell.nix @@ -4,9 +4,13 @@ let pkgs = import sources.nixpkgs {}; python = pkgs.python313Packages; haskell = pkgs.haskell.packages.ghc98; + PWD = builtins.getEnv "PWD"; in pkgs.stdenv.mkDerivation { name = "futhark"; + + EM_CACHE = "${PWD}/em_cache"; + buildInputs = with pkgs; [ @@ -42,6 +46,14 @@ pkgs.stdenv.mkDerivation { python.sphinx python.sphinxcontrib-bibtex imagemagick # needed for literate tests + ] + # The following are for WebGPU. + ++ [ + emscripten + python3Packages.aiohttp + python3Packages.selenium + chromium + chromedriver ] ++ lib.optionals (stdenv.isLinux) [ opencl-headers diff --git a/src/Futhark/Actions.hs b/src/Futhark/Actions.hs index 878d0f039f..f7ed7c9e17 100644 --- a/src/Futhark/Actions.hs +++ b/src/Futhark/Actions.hs @@ -12,6 +12,8 @@ module Futhark.Actions impCodeGenAction, kernelImpCodeGenAction, multicoreImpCodeGenAction, + webgpuImpCodeGenAction, + webgpuTestKernelsAction, metricsAction, compileCAction, compileCtoWASMAction, @@ -23,6 +25,7 @@ module Futhark.Actions compileMulticoreToWASMAction, compilePythonAction, compilePyOpenCLAction, + compileWebGPUAction, ) where @@ -44,6 +47,7 @@ import Futhark.Analysis.MemAlias qualified as MemAlias import Futhark.Analysis.Metrics import Futhark.CodeGen.Backends.CCUDA qualified as CCUDA import Futhark.CodeGen.Backends.COpenCL qualified as COpenCL +import Futhark.CodeGen.Backends.CWebGPU qualified as CWebGPU import Futhark.CodeGen.Backends.HIP qualified as HIP import Futhark.CodeGen.Backends.MulticoreC qualified as MulticoreC import Futhark.CodeGen.Backends.MulticoreISPC qualified as MulticoreISPC @@ -55,6 +59,7 @@ import Futhark.CodeGen.Backends.SequentialWASM qualified as SequentialWASM import Futhark.CodeGen.ImpGen.GPU qualified as ImpGenGPU import Futhark.CodeGen.ImpGen.Multicore qualified as ImpGenMulticore import Futhark.CodeGen.ImpGen.Sequential qualified as ImpGenSequential +import Futhark.CodeGen.ImpGen.WebGPU qualified as ImpGenWebGPU import Futhark.Compiler.CLI import Futhark.IR import Futhark.IR.GPUMem (GPUMem) @@ -62,6 +67,7 @@ import Futhark.IR.MCMem (MCMem) import Futhark.IR.SOACS (SOACS) import Futhark.IR.SeqMem (SeqMem) import Futhark.Optimise.Fusion.GraphRep qualified +import Futhark.Test.WebGPUTest qualified as WebGPUTest import Futhark.Util (runProgramWithExitCode, unixEnvironment) import Futhark.Util.Pretty (Doc, pretty, putDocLn, ()) import Futhark.Version (versionString) @@ -197,6 +203,24 @@ multicoreImpCodeGenAction = actionProcedure = liftIO . putStrLn . prettyString . snd <=< ImpGenMulticore.compileProg } +-- | Convert the program to WebGPU ImpCode and print it to stdout. 
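[Editor's note, not part of the patch] Stepping back to the 64-bit emulation in rts/wgsl/scalar64.wgsl above: WGSL has no native 64-bit integers, so i64 is a (low, high) pair of 32-bit words, and mul_u32_full builds the full 32×32→64 product from 16-bit limbs. A host-side C transcription of that limb arithmetic, checked against a native 64-bit multiply (illustration only, not part of the runtime):

#include <assert.h>
#include <stdint.h>
#include <stdio.h>

/* Same schoolbook 16-bit-limb multiply as mul_u32_full in scalar64.wgsl. */
static void mul_u32_full(uint32_t a32, uint32_t b32, uint32_t *lo, uint32_t *hi) {
  uint32_t al = a32 & 0xFFFFu, ah = a32 >> 16;
  uint32_t bl = b32 & 0xFFFFu, bh = b32 >> 16;
  uint32_t ll = al * bl, hh = ah * bh;
  uint32_t lh = al * bh, hl = ah * bl;
  uint32_t mid = hl + (ll >> 16) + (lh & 0xFFFFu); /* cannot overflow 32 bits */
  *lo = (mid << 16) | (ll & 0xFFFFu);
  *hi = hh + (mid >> 16) + (lh >> 16);
}

int main(void) {
  const uint32_t tests[][2] = {
    {0xDEADBEEFu, 0xCAFEBABEu}, {0xFFFFFFFFu, 0xFFFFFFFFu}, {3u, 7u}};
  for (int i = 0; i < 3; i++) {
    uint32_t lo, hi;
    mul_u32_full(tests[i][0], tests[i][1], &lo, &hi);
    uint64_t ref = (uint64_t)tests[i][0] * tests[i][1];
    assert(lo == (uint32_t)ref && hi == (uint32_t)(ref >> 32));
  }
  puts("mul_u32_full matches the native 64-bit product");
  return 0;
}

The signed mul_i64 then only has to add the cross terms a.x*b.y and b.x*a.y into the high word, as in the WGSL above.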
+webgpuImpCodeGenAction :: Action GPUMem +webgpuImpCodeGenAction = + Action + { actionName = "Compile imperative WebGPU", + actionDescription = "Translate program into imperative IL with WebGPU and write it on standard output.", + actionProcedure = liftIO . putStrLn . prettyString . snd <=< ImpGenWebGPU.compileProg + } + +-- | Convert the program to WebGPU ImpCode and generate test runner input. +webgpuTestKernelsAction :: FilePath -> Action GPUMem +webgpuTestKernelsAction f = + Action + { actionName = "Setup test for WebGPU WGSL kernels", + actionDescription = "Translate program into imperative IL with WebGPU and write it on standard output.", + actionProcedure = liftIO . putStrLn . prettyString <=< WebGPUTest.generateTests f + } + -- Lines that we prepend (in comments) to generated code. headerLines :: [T.Text] headerLines = T.lines $ "Generated by Futhark " <> versionString @@ -519,30 +543,97 @@ compilePyOpenCLAction fcfg mode outpath = actionProcedure = pythonCommon PyOpenCL.compileProg fcfg mode outpath } +-- | The @futhark webgpu@ action. +compileWebGPUAction :: FutharkConfig -> CompilerMode -> FilePath -> Action GPUMem +compileWebGPUAction fcfg mode tgtpath = + Action + { actionName = "Compile to WebGPU", + actionDescription = "Compile to WebGPU", + actionProcedure = helper + } + where + helper prog = do + (cprog, jslib, exports) <- + handleWarnings fcfg $ CWebGPU.compileProg versionString prog + let outpath = + if takeExtension tgtpath == "js" + then tgtpath + else tgtpath `addExtension` "js" + cpath = tgtpath `addExtension` "c" + jslibpath = tgtpath `addExtension` "wrapper.js" + jsserverpath = tgtpath `addExtension` "server.js" + jsonpath = tgtpath `addExtension` "json" + extra_options = + [ "-sUSE_WEBGPU", + "-sASYNCIFY", + "-sMODULARIZE", + "-sWASM_BIGINT", + "-sASSERTIONS", + "-sALLOW_MEMORY_GROWTH", + "-sEXPORTED_RUNTIME_METHODS=cwrap,ccall,Asyncify,HEAP8,HEAPU8,HEAP16,HEAPU16,HEAP32,HEAPU32,HEAP64,HEAPU64,HEAPF32,HEAPF64", + "--extern-post-js", + jslibpath + ] + export_option = + "-sEXPORTED_FUNCTIONS=" + ++ intercalate "," ['_' : T.unpack e | e <- exports] + case mode of + ToLibrary -> do + let (_header, impl, manifest) = CWebGPU.asLibrary cprog + liftIO $ T.writeFile cpath $ cPrependHeader impl + liftIO $ T.writeFile jslibpath jslib + liftIO $ T.writeFile jsonpath manifest + runEMCC + cpath + outpath + ["-O3", "-std=c99"] + ["-lm"] + (export_option : extra_options) + ToExecutable -> do + liftIO $ T.writeFile cpath $ cPrependHeader $ CWebGPU.asExecutable cprog + liftIO $ T.writeFile jslibpath jslib + runEMCC cpath outpath ["-O3", "-std=c99"] ["-lm"] extra_options + ToServer -> do + let (impl, server) = CWebGPU.asJSServer cprog + liftIO $ T.writeFile cpath $ cPrependHeader impl + liftIO $ T.writeFile jslibpath jslib + liftIO $ T.writeFile jsserverpath server + let serverArgs = ["--extern-post-js", jsserverpath] + runEMCC + cpath + outpath + ["-O3", "-std=c99"] + ["-lm"] + (export_option : extra_options ++ serverArgs) + cmdEMCFLAGS :: [String] -> [String] cmdEMCFLAGS def = maybe def words $ lookup "EMCFLAGS" unixEnvironment -runEMCC :: String -> String -> FilePath -> [String] -> [String] -> [String] -> Bool -> FutharkM () -runEMCC cpath outpath classpath cflags_def ldflags expfuns lib = do +wasmFlags :: FilePath -> [String] -> Bool -> [String] +wasmFlags classpath expfuns lib = + ["-lnodefs.js"] + ++ ["-s", "--extern-post-js", classpath] + ++ ( if lib + then ["-s", "EXPORT_NAME=loadWASM"] + else [] + ) + ++ ["-s", "WASM_BIGINT"] + ++ [ "-s", + "EXPORTED_FUNCTIONS=[" + ++ 
intercalate "," ("'_malloc'" : "'_free'" : expfuns) + ++ "]" + ] + +runEMCC :: String -> String -> [String] -> [String] -> [String] -> FutharkM () +runEMCC cpath outpath cflags_def ldflags extra_flags = do ret <- liftIO $ runProgramWithExitCode "emcc" ( [cpath, "-o", outpath] - ++ ["-lnodefs.js"] - ++ ["-s", "--extern-post-js", classpath] - ++ ( if lib - then ["-s", "EXPORT_NAME=loadWASM"] - else [] - ) - ++ ["-s", "WASM_BIGINT"] + ++ extra_flags ++ cmdCFLAGS cflags_def ++ cmdEMCFLAGS [""] - ++ [ "-s", - "EXPORTED_FUNCTIONS=[" - ++ intercalate "," ("'_malloc'" : "'_free'" : expfuns) - ++ "]" - ] -- The default LDFLAGS are always added. ++ ldflags ) @@ -575,12 +666,22 @@ compileCtoWASMAction fcfg mode outpath = ToLibrary -> do writeLibs cprog jsprog liftIO $ T.appendFile classpath SequentialWASM.libraryExports - runEMCC cpath mjspath classpath ["-O3", "-msimd128"] ["-lm"] exps True + runEMCC + cpath + mjspath + ["-O3", "-msimd128"] + ["-lm"] + (wasmFlags classpath exps True) _ -> do -- Non-server executables are not supported. writeLibs cprog jsprog liftIO $ T.appendFile classpath SequentialWASM.runServer - runEMCC cpath outpath classpath ["-O3", "-msimd128"] ["-lm"] exps False + runEMCC + cpath + outpath + ["-O3", "-msimd128"] + ["-lm"] + (wasmFlags classpath exps False) writeLibs cprog jsprog = do let (h, imp, _) = SequentialC.asLibrary cprog liftIO $ T.writeFile hpath h @@ -609,12 +710,22 @@ compileMulticoreToWASMAction fcfg mode outpath = ToLibrary -> do writeLibs cprog jsprog liftIO $ T.appendFile classpath MulticoreWASM.libraryExports - runEMCC cpath mjspath classpath ["-O3", "-msimd128"] ["-lm", "-pthread"] exps True + runEMCC + cpath + mjspath + ["-O3", "-msimd128"] + ["-lm", "-pthread"] + (wasmFlags classpath exps True) _ -> do -- Non-server executables are not supported. 
writeLibs cprog jsprog liftIO $ T.appendFile classpath MulticoreWASM.runServer - runEMCC cpath outpath classpath ["-O3", "-msimd128"] ["-lm", "-pthread"] exps False + runEMCC + cpath + outpath + ["-O3", "-msimd128"] + ["-lm", "-pthread"] + (wasmFlags classpath exps False) writeLibs cprog jsprog = do let (h, imp, _) = MulticoreC.asLibrary cprog diff --git a/src/Futhark/CLI/Dev.hs b/src/Futhark/CLI/Dev.hs index 37e1e1ce69..ea8cde56a2 100644 --- a/src/Futhark/CLI/Dev.hs +++ b/src/Futhark/CLI/Dev.hs @@ -514,6 +514,23 @@ commandLineOptions = opts {futharkAction = GPUMemAction $ \_ _ _ -> kernelImpCodeGenAction} ) "Translate pipeline result to ImpGPU and write it on stdout.", + Option + [] + ["compile-imp-webgpu"] + ( NoArg $ + Right $ \opts -> + opts {futharkAction = GPUMemAction $ \_ _ _ -> webgpuImpCodeGenAction} + ) + "Translate pipeline result to ImpWebGPU and write it on stdout.", + Option + [] + ["test-webgpu-kernels"] + ( NoArg $ + Right $ \opts -> + -- type BackendAction rep = FutharkConfig -> CompilerMode -> FilePath -> Action rep + opts {futharkAction = GPUMemAction $ \_ _ -> webgpuTestKernelsAction} + ) + "Translate pipeline result to ImpWebGPU and generate test runner input.", Option [] ["compile-imp-multicore"] diff --git a/src/Futhark/CLI/Main.hs b/src/Futhark/CLI/Main.hs index b4814d9b2e..98da3cbd8d 100644 --- a/src/Futhark/CLI/Main.hs +++ b/src/Futhark/CLI/Main.hs @@ -37,6 +37,7 @@ import Futhark.CLI.Run qualified as Run import Futhark.CLI.Script qualified as Script import Futhark.CLI.Test qualified as Test import Futhark.CLI.WASM qualified as WASM +import Futhark.CLI.WebGPU qualified as WebGPU import Futhark.Error import Futhark.Util (maxinum, showText) import Futhark.Util.Options @@ -66,6 +67,7 @@ commands = ("pyopencl", (PyOpenCL.main, "Compile to Python calling PyOpenCL.")), ("wasm", (WASM.main, "Compile to WASM with sequential C")), ("wasm-multicore", (MulticoreWASM.main, "Compile to WASM with multicore C")), + ("webgpu", (WebGPU.main, "Compile to JS/WASM calling WebGPU.")), ("ispc", (MulticoreISPC.main, "Compile to multicore ISPC")), ("test", (Test.main, "Test Futhark programs.")), ("bench", (Bench.main, "Benchmark Futhark programs.")), diff --git a/src/Futhark/CLI/WebGPU.hs b/src/Futhark/CLI/WebGPU.hs new file mode 100644 index 0000000000..cc8c9d2ecb --- /dev/null +++ b/src/Futhark/CLI/WebGPU.hs @@ -0,0 +1,17 @@ +-- | @futhark webgpu@ +module Futhark.CLI.WebGPU (main) where + +import Futhark.Actions (compileWebGPUAction) +import Futhark.Compiler.CLI +import Futhark.Passes (gpumemPipeline) + +-- | Run @futhark webgpu@. +main :: String -> [String] -> IO () +main = compilerMain + () + [] + "Compile WebGPU" + "Generate WebGPU C code from optimised Futhark program." 
+ gpumemPipeline + $ \fcfg () mode outpath prog -> + actionProcedure (compileWebGPUAction fcfg mode outpath) prog diff --git a/src/Futhark/CodeGen/Backends/CCUDA.hs b/src/Futhark/CodeGen/Backends/CCUDA.hs index f188fbd8c3..80346ea2f6 100644 --- a/src/Futhark/CodeGen/Backends/CCUDA.hs +++ b/src/Futhark/CodeGen/Backends/CCUDA.hs @@ -22,6 +22,7 @@ import Futhark.IR.GPUMem hiding ( CmpSizeLe, GetSize, GetSizeMax, + HostOp, ) import Futhark.MonadFreshNames import Language.C.Quote.OpenCL qualified as C @@ -33,7 +34,7 @@ mkBoilerplate :: M.Map Name KernelSafety -> [PrimType] -> [FailureMsg] -> - GC.CompilerM OpenCL () () + GC.CompilerM HostOp () () mkBoilerplate cuda_program macros kernels types failures = do generateGPUBoilerplate cuda_program @@ -102,7 +103,7 @@ cliOptions = } ] -cudaMemoryType :: GC.MemoryType OpenCL () +cudaMemoryType :: GC.MemoryType HostOp () cudaMemoryType "device" = pure [C.cty|typename CUdeviceptr|] cudaMemoryType space = error $ "GPU backend does not support '" ++ space ++ "' memory space." @@ -125,7 +126,7 @@ compileProg version prog = do cliOptions prog' where - operations :: GC.Operations OpenCL () + operations :: GC.Operations HostOp () operations = gpuOperations { GC.opsMemoryType = cudaMemoryType, diff --git a/src/Futhark/CodeGen/Backends/COpenCL.hs b/src/Futhark/CodeGen/Backends/COpenCL.hs index 3d57fb6c9e..07188ccdbe 100644 --- a/src/Futhark/CodeGen/Backends/COpenCL.hs +++ b/src/Futhark/CodeGen/Backends/COpenCL.hs @@ -24,6 +24,7 @@ import Futhark.IR.GPUMem hiding ( CmpSizeLe, GetSize, GetSizeMax, + HostOp, ) import Futhark.MonadFreshNames import Language.C.Quote.OpenCL qualified as C @@ -81,7 +82,7 @@ mkBoilerplate :: M.Map Name KernelSafety -> [PrimType] -> [FailureMsg] -> - GC.CompilerM OpenCL () () + GC.CompilerM HostOp () () mkBoilerplate opencl_program macros kernels types failures = do generateGPUBoilerplate opencl_program @@ -176,7 +177,7 @@ cliOptions = } ] -openclMemoryType :: GC.MemoryType OpenCL () +openclMemoryType :: GC.MemoryType HostOp () openclMemoryType "device" = pure [C.cty|typename cl_mem|] openclMemoryType space = error $ "GPU backend does not support '" ++ space ++ "' memory space." @@ -199,7 +200,7 @@ compileProg version prog = do cliOptions prog' where - operations :: GC.Operations OpenCL () + operations :: GC.Operations HostOp () operations = gpuOperations { GC.opsMemoryType = openclMemoryType diff --git a/src/Futhark/CodeGen/Backends/CWebGPU.hs b/src/Futhark/CodeGen/Backends/CWebGPU.hs new file mode 100644 index 0000000000..42d9cabdcc --- /dev/null +++ b/src/Futhark/CodeGen/Backends/CWebGPU.hs @@ -0,0 +1,436 @@ +{-# LANGUAGE QuasiQuotes #-} + +-- | Code generation for WebGPU. 
+module Futhark.CodeGen.Backends.CWebGPU + ( compileProg, + GC.CParts (..), + GC.asLibrary, + GC.asExecutable, + GC.asServer, + asJSServer, + ) +where + +import Data.Map qualified as M +import Data.Maybe (mapMaybe) +import Data.Set qualified as S +import Data.Text qualified as T +import Futhark.CodeGen.Backends.GPU +import Futhark.CodeGen.Backends.GenericC qualified as GC +import Futhark.CodeGen.Backends.GenericC.Options +import Futhark.CodeGen.Backends.GenericC.Pretty (idText) +import Futhark.CodeGen.ImpCode.WebGPU +import Futhark.CodeGen.ImpGen.WebGPU qualified as ImpGen +import Futhark.CodeGen.RTS.C (backendsWebGPUH) +import Futhark.CodeGen.RTS.WebGPU (serverWsJs, utilJs, valuesJs, wrappersJs) +import Futhark.CodeGen.RTS.WGSL qualified as RTS +import Futhark.IR.GPUMem (GPUMem, Prog) +import Futhark.MonadFreshNames +import Futhark.Util (chunk) +import Language.C.Quote.C qualified as C +import NeatInterpolation (text, untrimming) + +mkKernelInfos :: M.Map Name KernelInterface -> GC.CompilerM HostOp () () +mkKernelInfos kernels = do + mapM_ + GC.earlyDecl + [C.cunit|typedef struct wgpu_kernel_info { + char *name; + typename size_t num_scalars; + typename size_t scalars_binding; + typename size_t scalars_size; + typename size_t *scalar_offsets; + typename size_t num_bindings; + typename uint32_t *binding_indices; + typename size_t num_overrides; + char **used_overrides; + typename size_t num_dynamic_block_dims; + typename uint32_t *dynamic_block_dim_indices; + char **dynamic_block_dim_names; + typename uint32_t num_shared_mem_overrides; + char **shared_mem_overrides; + const char **gpu_program; + } wgpu_kernel_info; + static typename size_t wgpu_num_kernel_infos = $exp:num_kernels; |] + mapM_ GC.earlyDecl $ concatMap sc_offs_decl (M.toList kernels) + mapM_ GC.earlyDecl $ concatMap bind_idxs_decl (M.toList kernels) + mapM_ GC.earlyDecl $ concatMap used_overrides_decl (M.toList kernels) + mapM_ GC.earlyDecl $ concatMap dynamic_block_dim_indices_decl (M.toList kernels) + mapM_ GC.earlyDecl $ concatMap dynamic_block_dim_names_decl (M.toList kernels) + mapM_ GC.earlyDecl $ concatMap shared_mem_overrides_decl (M.toList kernels) + mapM_ GC.earlyDecl $ concatMap gpu_programs (M.toList kernels) + mapM_ + GC.earlyDecl + [C.cunit|static struct wgpu_kernel_info wgpu_kernel_infos[] + = {$inits:info_inits};|] + where + num_kernels = M.size kernels + sc_offs_decl (n, k) = + let offs = map (\o -> [C.cinit|$int:o|]) (scalarsOffsets k) + in [C.cunit|static typename size_t $id:(n <> "_scalar_offsets")[] + = {$inits:offs};|] + bind_idxs_decl (n, k) = + let idxs = map (\i -> [C.cinit|$int:i|]) (memBindSlots k) + in [C.cunit|static typename uint32_t $id:(n <> "_binding_indices")[] + = {$inits:idxs};|] + used_overrides_decl (n, k) = + let overrides = + map + (\o -> [C.cinit|$string:(T.unpack o)|]) + (overrideNames k) + in [C.cunit|static char* $id:(n <> "_used_overrides")[] + = {$inits:overrides};|] + dynamic_block_dim_indices_decl (n, k) = + let idxs = map ((\i -> [C.cinit|$int:i|]) . fst) (dynamicBlockDims k) + in [C.cunit|static typename uint32_t $id:(n <> "_dynamic_block_dim_indices")[] + = {$inits:idxs};|] + dynamic_block_dim_names_decl (n, k) = + let names = + map + ((\d -> [C.cinit|$string:(T.unpack d)|]) . 
snd) + (dynamicBlockDims k) + in [C.cunit|static char* $id:(n <> "_dynamic_block_dim_names")[] + = {$inits:names};|] + shared_mem_overrides_decl (n, k) = + let names = + map + (\d -> [C.cinit|$string:(T.unpack d)|]) + (sharedMemoryOverrides k) + in [C.cunit|static char* $id:(n <> "_shared_mem_overrides")[] + = {$inits:names};|] + gpu_programs (n, k) = + let program_fragments = [[C.cinit|$string:s|] | s <- chunk 2000 $ T.unpack $ gpuProgram k] + in [C.cunit|static const char* $id:(n <> "_gpu_program")[] + = {$inits:program_fragments, NULL};|] + + info_init (n, k) = + let num_scalars = length (scalarsOffsets k) + num_bindings = length (memBindSlots k) + num_overrides = length (overrideNames k) + num_dynamic_block_dims = length (dynamicBlockDims k) + num_shared_mem_overrides = length (sharedMemoryOverrides k) + in [C.cinit|{ .name = $string:(T.unpack (idText (C.toIdent n mempty))), + .num_scalars = $int:num_scalars, + .scalars_binding = $int:(scalarsBindSlot k), + .scalars_size = $int:(scalarsSize k), + .scalar_offsets = $id:(n <> "_scalar_offsets"), + .num_bindings = $int:num_bindings, + .binding_indices = $id:(n <> "_binding_indices"), + .num_overrides = $int:num_overrides, + .used_overrides = $id:(n <> "_used_overrides"), + .num_dynamic_block_dims = $int:num_dynamic_block_dims, + .dynamic_block_dim_indices = + $id:(n <> "_dynamic_block_dim_indices"), + .dynamic_block_dim_names = + $id:(n <> "_dynamic_block_dim_names"), + .num_shared_mem_overrides = $int:num_shared_mem_overrides, + .shared_mem_overrides = + $id:(n <> "_shared_mem_overrides"), + .gpu_program = $id:(n <> "_gpu_program") + }|] + info_inits = map info_init (M.toList kernels) + +-- We need to generate kernel_infos for built-in kernels as well. +builtinKernels :: M.Map Name KernelInterface +builtinKernels = + M.fromList $ concatMap generateKernels builtinKernelTemplates + where + builtinKernelTemplates = + [ ("lmad_copy_NAME", copyInterface RTS.lmad_copy) + , ("map_transpose_NAME", transposeInterface RTS.map_transpose) + , ("map_transpose_NAME_low_height", transposeInterface RTS.map_transpose_low_height) + , ("map_transpose_NAME_low_width", transposeInterface RTS.map_transpose_low_width) + , ("map_transpose_NAME_small", transposeInterface RTS.map_transpose_small) + , ("map_transpose_NAME_large", transposeInterfaceLarge RTS.map_transpose_large) + ] + + generateKernelProgram kernel name elemType atomic = + let baseKernel = if atomic + then T.replace "" ">" kernel + else kernel + in RTS.wgsl_prelude <> (T.replace "NAME" name $ T.replace "ELEM_TYPE" elemType baseKernel) + + generateKernels (template, interface) = + [(nameFromText (T.replace "NAME" name template), interface name elemType atomic) + | (name, elemType, atomic) <- [("1b", "i8", True) + ,("2b", "i16", True) + ,("4b", "i32", False) + ,("8b", "i64", False)]] + + transposeInterface program name elemType atomic = + KernelInterface + { safety = SafetyNone, + scalarsOffsets = [0, 8, 16, 20, 24, 28, 32, 36, 40], + scalarsSize = 48, -- uniform buffers must be multiple of 16 bytes + scalarsBindSlot = 0, + memBindSlots = [1, 2], + overrideNames = ["block_size_x", "block_size_y", "block_size_z"], + dynamicBlockDims = + [ (0, "block_size_x"), + (1, "block_size_y"), + (2, "block_size_z") + ], + sharedMemoryOverrides = [], + gpuProgram = generateKernelProgram program name elemType atomic + } + + transposeInterfaceLarge program name elemType atomic = + KernelInterface + { safety = SafetyNone, + scalarsOffsets = [0, 8, 16, 24, 32, 40, 48, 56, 60], + scalarsSize = 64, -- uniform buffers 
must be multiple of 16 bytes + scalarsBindSlot = 0, + memBindSlots = [1, 2], + overrideNames = ["block_size_x", "block_size_y", "block_size_z"], + dynamicBlockDims = + [ (0, "block_size_x"), + (1, "block_size_y"), + (2, "block_size_z") + ], + sharedMemoryOverrides = [], + gpuProgram = generateKernelProgram program name elemType atomic + } + + copyInterface program name elemType atomic = + KernelInterface + { safety = SafetyNone, + -- note that we need to align the 'r' field to 8-bytes + -- despite it being an i32 due to uniform buffer alignment + -- requirements + scalarsOffsets = [0, 8 .. 216], + scalarsSize = 224, -- uniform buffers must be multiple of 16 bytes + scalarsBindSlot = 0, + memBindSlots = [1, 2], + overrideNames = ["block_size_x", "block_size_y", "block_size_z"], + dynamicBlockDims = + [ (0, "block_size_x"), + (1, "block_size_y"), + (2, "block_size_z") + ], + sharedMemoryOverrides = [], + gpuProgram = generateKernelProgram program name elemType atomic + } + +mkBoilerplate :: + T.Text -> + [(Name, KernelConstExp)] -> + M.Map Name KernelInterface -> + [PrimType] -> + [FailureMsg] -> + GC.CompilerM HostOp () () +mkBoilerplate wgsl_program macros kernels types failures = do + mkKernelInfos (M.union kernels builtinKernels) + generateGPUBoilerplate + wgsl_program + macros + backendsWebGPUH + (M.keys kernels) + types + failures + + GC.headerDecl GC.InitDecl [C.cedecl|const char* futhark_context_config_get_program(struct futhark_context_config *cfg);|] + GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_set_program(struct futhark_context_config *cfg, const char* s);|] + +-- TODO: Check GPU.gpuOptions and see which of these make sense for us to +-- support. +cliOptions :: [Option] +cliOptions = [] + +webgpuMemoryType :: GC.MemoryType HostOp () +webgpuMemoryType "device" = pure [C.cty|typename WGPUBuffer|] +webgpuMemoryType space = error $ "WebGPU backend does not support '" ++ space ++ "' memory space." + +jsBoilerplate :: Definitions a -> T.Text -> (T.Text, [T.Text]) +jsBoilerplate prog manifest = + let (context, exports) = mkJsContext prog manifest + prelude = T.intercalate "\n" [utilJs, valuesJs, wrappersJs] + in (prelude <> "\n" <> context, exports ++ builtinExports) + where + builtinExports = + [ "malloc", + "free", + "futhark_context_config_new", + "futhark_context_new", + "futhark_context_config_free", + "futhark_context_free", + "futhark_context_sync", + "futhark_context_clear_caches", + "futhark_context_report", + "futhark_context_pause_profiling", + "futhark_context_unpause_profiling" + ] + +-- Argument should be a direct function call to a WASM function that needs to +-- be handled asynchronously. The return value evaluates to a Promise yielding +-- the function result when awaited. +-- Can currently only be used for code generated into the FutharkModule class. 
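[Editor's note, not part of the patch] The scalarsOffsets = [0, 8, 16, 20, 24, 28, 32, 36, 40] and scalarsSize = 48 recorded in transposeInterface above mirror the MapTransposeParameters uniform struct in the transpose shaders: two 8-byte i64 offsets followed by seven i32 fields, with the 44 bytes of fields rounded up because WebGPU requires uniform buffer binding sizes to be multiples of 16 bytes. A host-side C mirror of that layout, purely to make the numbers visible (the i64 fields are vec2 of u32 on the WGSL side; uint64_t here just pins the same 8-byte size and alignment, and in this C mirror the rounding to 48 happens to fall out of that alignment):

#include <stddef.h>
#include <stdint.h>

struct map_transpose_params {
  uint64_t dst_offset;  /* offset 0  */
  uint64_t src_offset;  /* offset 8  */
  int32_t num_arrays;   /* offset 16 */
  int32_t x_elems;      /* offset 20 */
  int32_t y_elems;      /* offset 24 */
  int32_t mulx;         /* offset 28 */
  int32_t muly;         /* offset 32 */
  int32_t repeat_1;     /* offset 36 */
  int32_t repeat_2;     /* offset 40 */
};

_Static_assert(offsetof(struct map_transpose_params, src_offset) == 8, "layout");
_Static_assert(offsetof(struct map_transpose_params, num_arrays) == 16, "layout");
_Static_assert(offsetof(struct map_transpose_params, repeat_2) == 40, "layout");
_Static_assert(sizeof(struct map_transpose_params) == 48, "44 bytes of fields, padded to 48");

The copyInterface above follows the same idea but pads every field to an 8-byte stride (offsets 0, 8, ..., 216) to satisfy the stricter uniform alignment of its argument struct.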
+asyncCall :: T.Text -> Bool -> [T.Text] -> T.Text +asyncCall func hasReturn args = + [text|this.m.ccall('${func}', ${ret}, ${argTypes}, ${argList}, {async: true})|] + where + ret = if hasReturn then "'number'" else "null" + argTypes = + "[" <> T.intercalate ", " (replicate (length args) "'number'") <> "]" + argList = + "[" <> T.intercalate ", " args <> "]" + +mkJsContext :: Definitions a -> T.Text -> (T.Text, [T.Text]) +mkJsContext (Definitions _ _ (Functions funs)) manifest = + ( [text| + class FutharkModule { + ${constructor} + ${free} + ${builtins} + }|], + entryExports ++ valueExports + ) + where + constructor = + [text| + constructor() { + this.m = undefined; + this.manifest = ${manifest}; + } + async init(module) { + this.m = module; + this.cfg = this.m._futhark_context_config_new(); + this.ctx = await ${newContext}; + this.entry = {}; + this.types = {}; + ${valueClasses} + ${entryPointFuns} + }|] + newContext = asyncCall "futhark_context_new" True ["this.cfg"] + free = + [text| + free() { + this.m._futhark_context_free(this.ctx); + this.m._futhark_context_config_free(this.cfg); + }|] + entryPoints = mapMaybe (functionEntry . snd) funs + (entryPointFuns, entryExports) = mkJsEntryPoints entryPoints + (valueClasses, valueExports) = mkJsValueClasses entryPoints + builtins = + [text| + malloc(nbytes) { + return this.m._malloc(nbytes); + } + free(ptr) { + return this.m._free(ptr); + } + async context_sync() { + return await ${syncCall}; + } + async clear_caches() { + return await ${clearCall}; + } + async report() { + return await this.m.ccall('futhark_context_report', 'string', + ['number'], [this.ctx], {async: true}); + } + async pause_profiling() { + return await ${pauseProfilingCall}; + } + async unpause_profiling() { + return await ${unpauseProfilingCall}; + } + |] + syncCall = asyncCall "futhark_context_sync" True ["this.ctx"] + clearCall = asyncCall "futhark_context_clear_caches" True ["this.ctx"] + pauseProfilingCall = + asyncCall "futhark_context_pause_profiling" False ["this.ctx"] + unpauseProfilingCall = + asyncCall "futhark_context_unpause_profiling" False ["this.ctx"] + +mkJsEntryPoints :: [EntryPoint] -> (T.Text, [T.Text]) +mkJsEntryPoints entries = (T.intercalate "\n" entryFuns, entryExports) + where + entryNames = map (nameToText . entryPointName) entries + entryFuns = map entryFun entryNames + entryExports = map entryExport entryNames + entryFun name = + [text|this.entry['${name}'] = make_entry_function(this, '${name}').bind(this);|] + entryExport name = "futhark_entry_" <> name + +mkJsValueClasses :: [EntryPoint] -> (T.Text, [T.Text]) +mkJsValueClasses entries = + -- TODO: Only supports transparent arrays right now. + let extVals = + concatMap (map snd . entryPointResults) entries + ++ concatMap (map snd . 
entryPointArgs) entries + transpVals = [v | TransparentValue v <- extVals] + arrVals = + S.toList $ + S.fromList + [(typ, sgn, shp) | ArrayValue _ _ typ sgn shp <- transpVals] + (cls, exports) = unzip $ map mkJsArrayClass arrVals + in (T.intercalate "\n" cls, concat exports) + +mkJsArrayClass :: (PrimType, Signedness, [DimSize]) -> (T.Text, [T.Text]) +mkJsArrayClass (typ, sign, shp) = + ( [text| + this.${name} = make_array_class(this, '${prettyName}'); + this.types['${prettyName}'] = this.${name}; + |], + exports + ) + where + rank = length shp + elemName = prettySigned (sign == Unsigned) typ + prettyName = mconcat (replicate rank "[]") <> elemName + name = elemName <> "_" <> prettyText rank <> "d" + exports = + [ "futhark_new_" <> name, + "futhark_free_" <> name, + "futhark_values_" <> name, + "futhark_shape_" <> name + ] + +-- | Compile the program to C with calls to WebGPU, along with a JS wrapper +-- library. +compileProg :: + (MonadFreshNames m) => + T.Text -> + Prog GPUMem -> + m (ImpGen.Warnings, (GC.CParts, T.Text, [T.Text])) +compileProg version prog = do + ( ws, + Program wgsl_code wgsl_prelude macros kernels params failures prog' + ) <- + ImpGen.compileProg prog + c <- + GC.compileProg + "webgpu" + version + params + operations + (mkBoilerplate (wgsl_prelude <> wgsl_code) macros kernels [] failures) + webgpu_includes + (Space "device", [Space "device", DefaultSpace]) + cliOptions + prog' + let (js, exports) = jsBoilerplate prog' (GC.cJsonManifest c) + pure (ws, (c, js, exports)) + where + operations :: GC.Operations HostOp () + operations = + gpuOperations + { GC.opsMemoryType = webgpuMemoryType + } + webgpu_includes = + [untrimming| + #ifdef USE_DAWN + #include + #else + #include + #include + #include + #endif + |] + +-- | As server script. Speaks custom protocol to local Python server +-- wrapper that speaks the actual Futhark server protocol. 
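[Editor's note, not part of the patch] compileProg above emits a C translation unit plus the JS wrapper and export list; the USE_DAWN branch of webgpu_includes suggests the C side can also be compiled against a native WebGPU implementation rather than through Emscripten. A minimal, hypothetical host program against the standard Futhark C API that this backend exports (the header name and build setup are assumptions, not something this diff pins down):

#include <stdio.h>
#include "prog.h"  /* hypothetical: header generated for the compiled program */

int main(void) {
  struct futhark_context_config *cfg = futhark_context_config_new();
  struct futhark_context *ctx = futhark_context_new(cfg);
  /* ... create arrays, call the generated futhark_entry_<name> functions,
     and futhark_context_sync(ctx) to wait for results ... */
  futhark_context_free(ctx);
  futhark_context_config_free(cfg);
  return 0;
}

On the web the same calls are reached through the FutharkModule wrapper above, which routes them through ccall with async: true because WebGPU device and queue operations are asynchronous.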
+asJSServer :: GC.CParts -> (T.Text, T.Text) +asJSServer parts = + let (_, c, _) = GC.asLibrary parts + in (c, serverWsJs) diff --git a/src/Futhark/CodeGen/Backends/GPU.hs b/src/Futhark/CodeGen/Backends/GPU.hs index cbae421e3b..c8774be19b 100644 --- a/src/Futhark/CodeGen/Backends/GPU.hs +++ b/src/Futhark/CodeGen/Backends/GPU.hs @@ -20,7 +20,7 @@ import Futhark.CodeGen.Backends.GenericC qualified as GC import Futhark.CodeGen.Backends.GenericC.Options import Futhark.CodeGen.Backends.GenericC.Pretty (expText, idText) import Futhark.CodeGen.Backends.SimpleRep (primStorageType, toStorage) -import Futhark.CodeGen.ImpCode.OpenCL +import Futhark.CodeGen.ImpCode.Kernels import Futhark.CodeGen.RTS.C (gpuH, gpuPrototypesH) import Futhark.MonadFreshNames import Futhark.Util (chunk) @@ -137,7 +137,7 @@ genLaunchKernel safety kernel_name shared_memory args num_tblocks tblock_size = v' ) -callKernel :: GC.OpCompiler OpenCL () +callKernel :: GC.OpCompiler HostOp () callKernel (GetSize v key) = GC.stm [C.cstm|$id:v = $exp:(getParamByKey key);|] callKernel (CmpSizeLe v key x) = do @@ -296,7 +296,7 @@ syncArg :: GC.CopyBarrier -> C.Exp syncArg GC.CopyBarrier = [C.cexp|true|] syncArg GC.CopyNoBarrier = [C.cexp|false|] -copyGPU :: GC.Copy OpenCL () +copyGPU :: GC.Copy HostOp () copyGPU _ dstmem dstidx (Space "device") srcmem srcidx (Space "device") nbytes = do p <- GC.provenanceExp GC.stm @@ -312,7 +312,7 @@ copyGPU b dstmem dstidx (Space "device") srcmem srcidx DefaultSpace nbytes = do copyGPU _ _ _ destspace _ _ srcspace _ = error $ "Cannot copy to " ++ show destspace ++ " from " ++ show srcspace -gpuOperations :: GC.Operations OpenCL () +gpuOperations :: GC.Operations HostOp () gpuOperations = GC.defaultOperations { GC.opsCompiler = callKernel, @@ -443,7 +443,7 @@ generateGPUBoilerplate :: [Name] -> [PrimType] -> [FailureMsg] -> - GC.CompilerM OpenCL () () + GC.CompilerM HostOp () () generateGPUBoilerplate gpu_program macros backendH kernels types failures = do createKernels kernels let gpu_program_fragments = diff --git a/src/Futhark/CodeGen/Backends/HIP.hs b/src/Futhark/CodeGen/Backends/HIP.hs index b6ae4f78d1..146e8b892a 100644 --- a/src/Futhark/CodeGen/Backends/HIP.hs +++ b/src/Futhark/CodeGen/Backends/HIP.hs @@ -22,6 +22,7 @@ import Futhark.IR.GPUMem hiding ( CmpSizeLe, GetSize, GetSizeMax, + HostOp, ) import Futhark.MonadFreshNames import Language.C.Quote.OpenCL qualified as C @@ -33,7 +34,7 @@ mkBoilerplate :: M.Map Name KernelSafety -> [PrimType] -> [FailureMsg] -> - GC.CompilerM OpenCL () () + GC.CompilerM HostOp () () mkBoilerplate hip_program macros kernels types failures = do generateGPUBoilerplate hip_program @@ -84,7 +85,7 @@ cliOptions = } ] -hipMemoryType :: GC.MemoryType OpenCL () +hipMemoryType :: GC.MemoryType HostOp () hipMemoryType "device" = pure [C.cty|typename hipDeviceptr_t|] hipMemoryType space = error $ "GPU backend does not support '" ++ space ++ "' memory space." 
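[Editor's note, not part of the patch] Returning to the wgpu_kernel_info table emitted by mkKernelInfos earlier in this file set: the scalar_offsets, scalars_size and scalars_binding fields describe one uniform buffer per kernel into which all scalar arguments are packed before a dispatch. A hypothetical sketch of that packing step, using the field names of the generated struct (the real logic lives in rts/c/backends/webgpu.h, which this excerpt does not show, and may differ):

#include <stddef.h>
#include <string.h>

/* Copy each scalar argument into a scalars_size-byte staging blob at its
   recorded offset; the blob is then uploaded as the uniform buffer bound at
   scalars_binding for the kernel. */
static void pack_scalars(unsigned char *blob,          /* scalars_size bytes */
                         const size_t *scalar_offsets, /* from wgpu_kernel_info */
                         const void *const *args,      /* pointers to the scalars */
                         const size_t *arg_sizes,      /* size of each scalar */
                         size_t num_scalars) {
  for (size_t i = 0; i < num_scalars; i++)
    memcpy(blob + scalar_offsets[i], args[i], arg_sizes[i]);
}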
@@ -107,7 +108,7 @@ compileProg version prog = do cliOptions prog' where - operations :: GC.Operations OpenCL () + operations :: GC.Operations HostOp () operations = gpuOperations { GC.opsMemoryType = hipMemoryType diff --git a/src/Futhark/CodeGen/Backends/PyOpenCL.hs b/src/Futhark/CodeGen/Backends/PyOpenCL.hs index 394a019b23..1cafade701 100644 --- a/src/Futhark/CodeGen/Backends/PyOpenCL.hs +++ b/src/Futhark/CodeGen/Backends/PyOpenCL.hs @@ -185,7 +185,7 @@ compileProg mode class_name prog = do options prog' where - operations :: Operations Imp.OpenCL () + operations :: Operations Imp.HostOp () operations = Operations { opsCompiler = callKernel, @@ -220,7 +220,7 @@ compileBlockDim :: Imp.BlockDim -> CompilerM op s PyExp compileBlockDim (Left e) = asLong <$> compileExp e compileBlockDim (Right e) = pure $ compileConstExp e -callKernel :: OpCompiler Imp.OpenCL () +callKernel :: OpCompiler Imp.HostOp () callKernel (Imp.GetSize v key) = do v' <- compileVar v stm $ Assign v' $ getParamByKey key @@ -283,7 +283,7 @@ launchKernel kernel_name safety kernel_dims threadblock_dims shared_memory args processKernelArg (Imp.ValueKArg e bt) = toStorage bt <$> compileExp e processKernelArg (Imp.MemKArg v) = compileVar v -writeOpenCLScalar :: WriteScalar Imp.OpenCL () +writeOpenCLScalar :: WriteScalar Imp.HostOp () writeOpenCLScalar mem i bt "device" val = do let nparr = Call @@ -302,7 +302,7 @@ writeOpenCLScalar mem i bt "device" val = do writeOpenCLScalar _ _ _ space _ = error $ "Cannot write to '" ++ space ++ "' memory space." -readOpenCLScalar :: ReadScalar Imp.OpenCL () +readOpenCLScalar :: ReadScalar Imp.HostOp () readOpenCLScalar mem i bt "device" = do val <- newVName "read_res" let val' = Var $ prettyString val @@ -328,7 +328,7 @@ readOpenCLScalar mem i bt "device" = do readOpenCLScalar _ _ _ space = error $ "Cannot read from '" ++ space ++ "' memory space." -allocateOpenCLBuffer :: Allocate Imp.OpenCL () +allocateOpenCLBuffer :: Allocate Imp.HostOp () allocateOpenCLBuffer mem size "device" = stm $ Assign mem $ @@ -336,7 +336,7 @@ allocateOpenCLBuffer mem size "device" = allocateOpenCLBuffer _ _ space = error $ "Cannot allocate in '" ++ space ++ "' space" -packArrayOutput :: EntryOutput Imp.OpenCL () +packArrayOutput :: EntryOutput Imp.HostOp () packArrayOutput mem "device" bt ept dims = do mem' <- compileVar mem dims' <- mapM compileDim dims @@ -351,7 +351,7 @@ packArrayOutput mem "device" bt ept dims = do packArrayOutput _ sid _ _ _ = error $ "Cannot return array from " ++ sid ++ " space." -unpackArrayInput :: EntryInput Imp.OpenCL () +unpackArrayInput :: EntryInput Imp.HostOp () unpackArrayInput mem "device" t s dims e = do let type_is_ok = BinOp diff --git a/src/Futhark/CodeGen/ImpCode/GPU.hs b/src/Futhark/CodeGen/ImpCode/GPU.hs index f303ac0251..564c2adf15 100644 --- a/src/Futhark/CodeGen/ImpCode/GPU.hs +++ b/src/Futhark/CodeGen/ImpCode/GPU.hs @@ -174,6 +174,7 @@ data KernelOp | Barrier Fence | MemFence Fence | SharedAlloc VName (Count Bytes (TExp Int64)) + | UniformRead VName VName (Count Elements (TExp Int64)) PrimType Space | -- | Perform a barrier and also check whether any -- threads have failed an assertion. 
Make sure all -- threads would reach all 'ErrorSync's if any of them @@ -252,6 +253,16 @@ instance Pretty KernelOp where "error_sync_local()" pretty (ErrorSync FenceGlobal) = "error_sync_global()" + pretty (UniformRead name v is bt space') = + pretty name + <+> "<-" + <+> "read_uniform" + <> parens + ( commasep + [ pretty v <> langle <> pretty bt <> pretty space' <> rangle, + pretty is + ] + ) pretty (Atomic _ (AtomicAdd t old arr ind x)) = pretty old <+> "<-" diff --git a/src/Futhark/CodeGen/ImpCode/Kernels.hs b/src/Futhark/CodeGen/ImpCode/Kernels.hs new file mode 100644 index 0000000000..fc256f2316 --- /dev/null +++ b/src/Futhark/CodeGen/ImpCode/Kernels.hs @@ -0,0 +1,79 @@ +-- | Common definitions for imperative code augmented with the ability to launch +-- kernels. +module Futhark.CodeGen.ImpCode.Kernels + ( KernelName, + KernelArg (..), + HostCode, + HostOp (..), + KernelSafety (..), + numFailureParams, + FailureMsg (..), + BlockDim, + KernelConst (..), + KernelConstExp, + module Futhark.CodeGen.ImpCode, + module Futhark.IR.GPU.Sizes, + ) +where + +import Futhark.CodeGen.ImpCode +import Futhark.CodeGen.ImpCode.GPU (BlockDim, KernelConst (..), KernelConstExp) +import Futhark.IR.GPU.Sizes +import Futhark.Util.Pretty + +-- | Something that can go wrong in a kernel. Part of the machinery +-- for reporting error messages from within kernels. +data FailureMsg = FailureMsg + { failureError :: ErrorMsg Exp, + failureBacktrace :: String + } + +-- | A piece of code calling kernels. +type HostCode = Code HostOp + +-- | The name of a kernel. +type KernelName = Name + +-- | An argument to be passed to a kernel. +data KernelArg + = -- | Pass the value of this scalar expression as argument. + ValueKArg Exp PrimType + | -- | Pass this pointer as argument. + MemKArg VName + deriving (Show) + +-- | Whether a kernel can potentially fail (because it contains bounds +-- checks and such). +data MayFail = MayFail | CannotFail + deriving (Show) + +-- | Information about bounds checks and how sensitive it is to +-- errors. Ordered by least demanding to most. +data KernelSafety + = -- | Does not need to know if we are in a failing state, and also + -- cannot fail. + SafetyNone + | -- | Needs to be told if there's a global failure, and that's it, + -- and cannot fail. + SafetyCheap + | -- | Needs all parameters, may fail itself. + SafetyFull + deriving (Eq, Ord, Show) + +-- | How many leading failure arguments we must pass when launching a +-- kernel with these safety characteristics. +numFailureParams :: KernelSafety -> Int +numFailureParams SafetyNone = 0 +numFailureParams SafetyCheap = 1 +numFailureParams SafetyFull = 3 + +-- | Host-level kernel operation. +data HostOp + = LaunchKernel KernelSafety KernelName (Count Bytes (TExp Int64)) [KernelArg] [Exp] [BlockDim] + | GetSize VName Name + | CmpSizeLe VName Name Exp + | GetSizeMax VName SizeClass + deriving (Show) + +instance Pretty HostOp where + pretty = pretty . show diff --git a/src/Futhark/CodeGen/ImpCode/OpenCL.hs b/src/Futhark/CodeGen/ImpCode/OpenCL.hs index 8a0fb7029e..6bc6219596 100644 --- a/src/Futhark/CodeGen/ImpCode/OpenCL.hs +++ b/src/Futhark/CodeGen/ImpCode/OpenCL.hs @@ -8,28 +8,14 @@ -- operation that allows one to execute an OpenCL kernel. 
module Futhark.CodeGen.ImpCode.OpenCL ( Program (..), - KernelName, - KernelArg (..), - CLCode, - OpenCL (..), - KernelSafety (..), - numFailureParams, KernelTarget (..), - FailureMsg (..), - BlockDim, - KernelConst (..), - KernelConstExp, - module Futhark.CodeGen.ImpCode, - module Futhark.IR.GPU.Sizes, + module Futhark.CodeGen.ImpCode.Kernels, ) where import Data.Map qualified as M import Data.Text qualified as T -import Futhark.CodeGen.ImpCode -import Futhark.CodeGen.ImpCode.GPU (BlockDim, KernelConst (..), KernelConstExp) -import Futhark.IR.GPU.Sizes -import Futhark.Util.Pretty +import Futhark.CodeGen.ImpCode.Kernels -- | An program calling OpenCL kernels. data Program = Program @@ -46,69 +32,12 @@ data Program = Program openClParams :: ParamMap, -- | Assertion failure error messages. openClFailures :: [FailureMsg], - hostDefinitions :: Definitions OpenCL + hostDefinitions :: Definitions HostOp } --- | Something that can go wrong in a kernel. Part of the machinery --- for reporting error messages from within kernels. -data FailureMsg = FailureMsg - { failureError :: ErrorMsg Exp, - failureBacktrace :: String - } - --- | A piece of code calling OpenCL. -type CLCode = Code OpenCL - --- | The name of a kernel. -type KernelName = Name - --- | An argument to be passed to a kernel. -data KernelArg - = -- | Pass the value of this scalar expression as argument. - ValueKArg Exp PrimType - | -- | Pass this pointer as argument. - MemKArg VName - deriving (Show) - --- | Whether a kernel can potentially fail (because it contains bounds --- checks and such). -data MayFail = MayFail | CannotFail - deriving (Show) - --- | Information about bounds checks and how sensitive it is to --- errors. Ordered by least demanding to most. -data KernelSafety - = -- | Does not need to know if we are in a failing state, and also - -- cannot fail. - SafetyNone - | -- | Needs to be told if there's a global failure, and that's it, - -- and cannot fail. - SafetyCheap - | -- | Needs all parameters, may fail itself. - SafetyFull - deriving (Eq, Ord, Show) - --- | How many leading failure arguments we must pass when launching a --- kernel with these safety characteristics. -numFailureParams :: KernelSafety -> Int -numFailureParams SafetyNone = 0 -numFailureParams SafetyCheap = 1 -numFailureParams SafetyFull = 3 - --- | Host-level OpenCL operation. -data OpenCL - = LaunchKernel KernelSafety KernelName (Count Bytes (TExp Int64)) [KernelArg] [Exp] [BlockDim] - | GetSize VName Name - | CmpSizeLe VName Name Exp - | GetSizeMax VName SizeClass - deriving (Show) - -- | The target platform when compiling imperative code to a 'Program' data KernelTarget = TargetOpenCL | TargetCUDA | TargetHIP deriving (Eq) - -instance Pretty OpenCL where - pretty = pretty . show diff --git a/src/Futhark/CodeGen/ImpCode/WebGPU.hs b/src/Futhark/CodeGen/ImpCode/WebGPU.hs new file mode 100644 index 0000000000..2d3da6bdda --- /dev/null +++ b/src/Futhark/CodeGen/ImpCode/WebGPU.hs @@ -0,0 +1,76 @@ +-- | Imperative code with a WebGPU component. +-- +-- Apart from ordinary imperative code, this also carries around a +-- WebGPU program as a string, as well as a list of kernels defined by +-- the program. +-- +-- The imperative code has been augmented with a 'LaunchKernel' +-- operation that allows one to execute a WebGPU kernel. 
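+--
+-- As a sketch of the representation (kernel and variable names here are
+-- invented for illustration), launching a kernel that takes one scalar and
+-- one memory argument and uses no shared memory looks roughly like
+--
+-- > LaunchKernel SafetyNone "map_kernel_1234" 0
+-- >   [ValueKArg (LeafExp n (IntType Int32)) (IntType Int32), MemKArg xs_mem]
+-- >   num_blocks block_dims
+--
+-- where @num_blocks :: [Exp]@ and @block_dims :: [BlockDim]@ are evaluated on
+-- the host before the dispatch is issued.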
+module Futhark.CodeGen.ImpCode.WebGPU + ( KernelInterface (..), + Program (..), + module Futhark.CodeGen.ImpCode.Kernels, + ) +where + +import Data.Map qualified as M +import Data.Text qualified as T +import Futhark.CodeGen.ImpCode.Kernels +import Futhark.Util.Pretty + +-- | The interface to a WebGPU/WGSL kernel. +-- +-- Arguments are assumed to be passed as shared memory sizes first, then +-- scalars, and then memory bindings. +data KernelInterface = KernelInterface + { safety :: KernelSafety, + -- | Offsets of all fields in the corresponding scalars struct. + scalarsOffsets :: [Int], + -- | Total size in bytes of the scalars uniform buffer. + scalarsSize :: Int, + -- | Bind slot index for the scalars uniform buffer. + scalarsBindSlot :: Int, + -- | Bind slot indices for all memory arguments. + memBindSlots :: [Int], + -- | Names of all the override declarations used by the kernel. Should only + -- be required for the ad-hoc WGSL testing setup, in normal code generation + -- these get passed through 'webgpuMacroDefs'. + -- Currently also used to work around a Chrome/Dawn bug, see + -- `gpu_create_kernel` in rts/c/backends/webgpu.h. + overrideNames :: [T.Text], + -- | Dynamic block dimensions, with the corresponding override name. They + -- are also included in `overrideNames`. + dynamicBlockDims :: [(Int, T.Text)], + -- | Override names for shared memory sizes. They are also included in + -- `overrideNames`. + sharedMemoryOverrides :: [T.Text], + -- | WGSL source of the kernel. This is only used for the built-in kernels. + -- The compiled wgsl kernel from futhark source is still stored in the gpu_program. + gpuProgram :: T.Text + } + +-- | A program calling WebGPU kernels. +data Program = Program + { webgpuProgram :: T.Text, + -- | Must be prepended to the program. + webgpuPrelude :: T.Text, + -- | Definitions to be passed as macro definitions to the kernel + -- compiler. + webgpuMacroDefs :: [(Name, KernelConstExp)], + webgpuKernels :: M.Map KernelName KernelInterface, + -- | Runtime-configurable constants. + webgpuParams :: ParamMap, + -- | Assertion failure error messages. 
+ webgpuFailures :: [FailureMsg], + hostDefinitions :: Definitions HostOp + } + +instance Pretty Program where + pretty prog = + -- TODO: print everything + "webgpu {" + indent 2 (stack $ map pretty $ T.lines $ webgpuPrelude prog) + indent 2 (stack $ map pretty $ T.lines $ webgpuProgram prog) + "}" + "" + <> pretty (hostDefinitions prog) diff --git a/src/Futhark/CodeGen/ImpGen/GPU/SegRed.hs b/src/Futhark/CodeGen/ImpGen/GPU/SegRed.hs index ad6d7ed8d3..6817254f9c 100644 --- a/src/Futhark/CodeGen/ImpGen/GPU/SegRed.hs +++ b/src/Futhark/CodeGen/ImpGen/GPU/SegRed.hs @@ -952,7 +952,14 @@ reductionStageTwo segred_pes tblock_id segment_gtids first_block_for_segment blo sOp $ Imp.Barrier Imp.FenceGlobal is_last_block <- dPrim "is_last_block" - copyDWIMFix (tvVar is_last_block) [] (Var sync_arr) [0] + (sync_arr_mem, sync_arr_space, sync_arr_is) <- fullyIndexArray sync_arr [0] + sOp $ + Imp.UniformRead + (tvVar is_last_block) + sync_arr_mem + sync_arr_is + Bool + sync_arr_space sWhen (tvExp is_last_block) $ do -- The final block has written its result (and it was -- us!), so read in all the block results and perform the diff --git a/src/Futhark/CodeGen/ImpGen/GPU/ToOpenCL.hs b/src/Futhark/CodeGen/ImpGen/GPU/ToOpenCL.hs index 8692e5c030..6ae0e171c7 100644 --- a/src/Futhark/CodeGen/ImpGen/GPU/ToOpenCL.hs +++ b/src/Futhark/CodeGen/ImpGen/GPU/ToOpenCL.hs @@ -113,7 +113,7 @@ cleanSizes m = M.map clean m findParamUsers :: Env -> - Definitions ImpOpenCL.OpenCL -> + Definitions ImpOpenCL.HostOp -> M.Map Name SizeClass -> ParamMap findParamUsers env defs = M.mapWithKey onParam @@ -218,7 +218,7 @@ envFromProg prog = Env funs (funsMayFail cg funs) cg funs = defFuns prog cg = ImpGPU.callGraph calledInHostOp funs -lookupFunction :: Name -> Env -> Maybe (ImpGPU.Function HostOp) +lookupFunction :: Name -> Env -> Maybe (ImpGPU.Function ImpGPU.HostOp) lookupFunction fname = lookup fname . unFunctions . envFuns functionMayFail :: Name -> Env -> Bool @@ -230,8 +230,8 @@ addSize :: Name -> SizeClass -> OnKernelM () addSize key sclass = modify $ \s -> s {clSizes = M.insert key sclass $ clSizes s} -onHostOp :: KernelTarget -> HostOp -> OnKernelM OpenCL -onHostOp target (CallKernel k) = onKernel target k +onHostOp :: KernelTarget -> ImpGPU.HostOp -> OnKernelM ImpOpenCL.HostOp +onHostOp target (ImpGPU.CallKernel k) = onKernel target k onHostOp _ (ImpGPU.GetSize v key size_class) = do addSize key size_class pure $ ImpOpenCL.GetSize v key @@ -303,7 +303,7 @@ ensureDeviceFun fname host_func = do exists <- gets $ M.member fname . clDevFuns unless exists $ generateDeviceFun fname host_func -calledInHostOp :: HostOp -> S.Set Name +calledInHostOp :: ImpGPU.HostOp -> S.Set Name calledInHostOp (CallKernel k) = calledFuncs calledInKernelOp $ kernelBody k calledInHostOp _ = mempty @@ -327,7 +327,7 @@ ensureDeviceFuns code = do Nothing -> pure Nothing where bad = compilerLimitationS "Cannot generate GPU functions that contain parallelism." 
- toDevice :: HostOp -> KernelOp + toDevice :: ImpGPU.HostOp -> KernelOp toDevice _ = bad isConst :: BlockDim -> Maybe KernelConstExp @@ -337,7 +337,7 @@ isConst (Right e) = Just e isConst _ = Nothing -onKernel :: KernelTarget -> Kernel -> OnKernelM OpenCL +onKernel :: KernelTarget -> Kernel -> OnKernelM ImpOpenCL.HostOp onKernel target kernel = do called <- ensureDeviceFuns $ kernelBody kernel @@ -641,6 +641,8 @@ inKernelOperations env mode body = GC.modifyUserState $ \s -> s {kernelSharedMemory = (name', size) : kernelSharedMemory s} GC.stm [C.cstm|$id:name = (__local unsigned char*) $id:name';|] + kernelOps (UniformRead dest src i typ space) = + GC.compileCode (Read dest src i typ space Nonvolatile) kernelOps (ErrorSync f) = do label <- nextErrorLabel pending <- kernelSyncPending <$> GC.getUserState diff --git a/src/Futhark/CodeGen/ImpGen/WebGPU.hs b/src/Futhark/CodeGen/ImpGen/WebGPU.hs new file mode 100644 index 0000000000..99f8852057 --- /dev/null +++ b/src/Futhark/CodeGen/ImpGen/WebGPU.hs @@ -0,0 +1,1255 @@ +{-# LANGUAGE LambdaCase #-} + +-- | Code generation for ImpCode with WebGPU. +module Futhark.CodeGen.ImpGen.WebGPU + ( compileProg, + Warnings, + ) +where + +import Control.Monad (forM, forM_, liftM2, liftM3, unless, when) +import Control.Monad.Trans.Class +import Control.Monad.Trans.RWS +import Control.Monad.Trans.State qualified as State +import Data.Bifunctor (first, second) +import Data.Bits qualified as Bits +import Data.List qualified as L +import Data.Map qualified as M +import Data.Maybe (catMaybes, fromMaybe) +import Data.Set qualified as S +import Data.Text qualified as T +import Debug.Trace (traceM) +import Futhark.CodeGen.ImpCode.GPU qualified as ImpGPU +import Futhark.CodeGen.ImpCode.WebGPU +import Futhark.CodeGen.ImpGen.GPU qualified as ImpGPU +import Futhark.CodeGen.RTS.WGSL qualified as RTS +import Futhark.Error (compilerLimitation) +import Futhark.IR.GPUMem qualified as F +import Futhark.MonadFreshNames +import Futhark.Util (convFloat, nubOrd, zEncodeText) +import Futhark.Util.Pretty (align, docText, indent, pretty, ()) +import Language.Futhark.Warnings (Warnings) +import Language.WGSL qualified as WGSL + +-- State carried during WebGPU translation. +data WebGPUS = WebGPUS + { -- | Accumulated code. + wsCode :: T.Text, + wsSizes :: M.Map Name SizeClass, + wsMacroDefs :: [(Name, KernelConstExp)], + -- | Interface of kernels already generated into wsCode. + wsKernels :: [(WGSL.Ident, KernelInterface)], + wsNextBindSlot :: Int, + -- | TODO comment on this + wsDevFuns :: S.Set Name, + wsFuns :: ImpGPU.Functions ImpGPU.HostOp, + wsFunsMayFail :: S.Set Name + } + +-- The monad in which we perform the overall translation. +type WebGPUM = State.State WebGPUS + +addSize :: Name -> SizeClass -> WebGPUM () +addSize key sclass = + State.modify $ \s -> s {wsSizes = M.insert key sclass $ wsSizes s} + +addMacroDef :: Name -> KernelConstExp -> WebGPUM () +addMacroDef key e = + State.modify $ \s -> s {wsMacroDefs = (key, e) : wsMacroDefs s} + +addCode :: T.Text -> WebGPUM () +addCode code = + State.modify $ \s -> s {wsCode = wsCode s <> code} + +newtype KernelR = KernelR {krKernel :: ImpGPU.Kernel} + +data KernelState = KernelState + { -- Kernel declarations and body + ksDecls :: [WGSL.Declaration], + ksInits :: [WGSL.Stmt], + ksBody :: [WGSL.Stmt], + -- | Identifier replacement map. We have to rename some identifiers; when + -- translating Imp Code and PrimExps this map is consulted to respect the + -- renaming. 
+ ksNameReplacements :: M.Map WGSL.Ident WGSL.Ident, + -- These describe the kernel interface. + ksOverrides :: [WGSL.Ident], + ksBlockDims :: [(Int, WGSL.Ident, Bool)], + -- TODO: Might be nice to combine sharedMem and atomicMem into some more + -- general information about memory in scope + ksSharedMem :: [(WGSL.Ident, Exp)], + ksAtomicMem :: [WGSL.Ident], + ksScalars :: [WGSL.PrimType], + ksBindSlots :: [Int] + } + +type KernelM = RWST KernelR () KernelState WebGPUM + +addRename :: WGSL.Ident -> WGSL.Ident -> KernelM () +addRename old new = modify $ + \s -> s {ksNameReplacements = M.insert old new (ksNameReplacements s)} + +-- | Some names generated are unique in the scope of a single kernel but are +-- translated to module-scope identifiers in WGSL. This modifies an identifier +-- to be unique in that scope. +mkGlobalIdent :: WGSL.Ident -> KernelM WGSL.Ident +mkGlobalIdent ident = do + kernelName <- asks (textToIdent . nameToText . ImpGPU.kernelName . krKernel) + pure $ kernelName <> "_" <> ident + +addDecl :: WGSL.Declaration -> KernelM () +addDecl decl = modify $ \s -> s {ksDecls = ksDecls s ++ [decl]} + +prependDecl :: WGSL.Declaration -> KernelM () +prependDecl decl = modify $ \s -> s {ksDecls = decl : ksDecls s} + +addInitStmt :: WGSL.Stmt -> KernelM () +addInitStmt stmt = modify $ \s -> s {ksInits = ksInits s ++ [stmt]} + +addBodyStmt :: WGSL.Stmt -> KernelM () +addBodyStmt stmt = modify $ \s -> s {ksBody = ksBody s ++ [stmt]} + +-- | Produces an identifier for the given name, respecting the name replacements +-- map. +getIdent :: (F.Pretty a) => a -> KernelM WGSL.Ident +getIdent name = gets (M.findWithDefault t t . ksNameReplacements) + where + t = zEncodeText $ prettyText name + +-- | Get a new, unused binding index and add it to the list of bind slots used +-- by the current kernel. +assignBindSlot :: KernelM Int +assignBindSlot = do + wState <- lift State.get + let slot = wsNextBindSlot wState + modify $ \s -> s {ksBindSlots = ksBindSlots s ++ [slot]} + lift $ State.put (wState {wsNextBindSlot = slot + 1}) + pure slot + +-- | Add an override declaration to the current kernel's interface and into the +-- module. +addOverride :: WGSL.Ident -> WGSL.Typ -> Maybe WGSL.Exp -> KernelM () +addOverride ident typ e = do + addDecl $ WGSL.OverrideDecl ident typ e + modify $ \s -> s {ksOverrides = ksOverrides s ++ [ident]} + +-- | Register an override identifier as describing the given dimension of the +-- block size of the current kernel. +addBlockDim :: Int -> WGSL.Ident -> Bool -> KernelM () +addBlockDim dim ident dynamic = + modify $ \s -> s {ksBlockDims = (dim, ident, dynamic) : ksBlockDims s} + +-- | Registers an override identifier as describing the size of a shared memory +-- buffer, with the expression being evaluated to get the size when launching +-- the kernel. +addSharedMem :: WGSL.Ident -> Exp -> KernelM () +addSharedMem ident e = + modify $ \s -> s {ksSharedMem = (ident, e) : ksSharedMem s} + +addAtomicMem :: WGSL.Ident -> KernelM () +addAtomicMem ident = modify $ \s -> s {ksAtomicMem = ident : ksAtomicMem s} + +-- | Whether the identifier is the name of a shared memory allocation. +-- TODO: Should probably store the allocation name in the state instead of +-- reconstructing the _size name here. +isShared :: WGSL.Ident -> KernelM Bool +isShared ident = any (\(sz, _) -> sz == ident <> "_size") <$> gets ksSharedMem + +isAtomic :: WGSL.Ident -> KernelM Bool +isAtomic ident = elem ident <$> gets ksAtomicMem + +-- | Add a scalar struct field. 
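+--
+-- Scalars registered here are later emitted by 'genScalarDecls' as a single
+-- uniform struct per kernel. Illustratively (identifiers invented, the
+-- binding slot comes from 'assignBindSlot'), a kernel using an @i32@ and an
+-- @i64@ scalar ends up with WGSL along the lines of
+--
+-- > struct kernelname_Scalars { n_1234: i32, m_1235: vec2<i32> }
+-- > @group(0) @binding(0) var<uniform> kernelname_scalars: kernelname_Scalars;
+--
+-- and the host writes the argument values at the offsets recorded in
+-- 'scalarsOffsets'.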
+addScalar :: WGSL.PrimType -> KernelM () +addScalar typ = modify $ \s -> s {ksScalars = ksScalars s ++ [typ]} + +entryParams :: [WGSL.Param] +entryParams = + [ WGSL.Param + "workgroup_id" + (WGSL.Prim (WGSL.Vec3 WGSL.UInt32)) + [WGSL.Attrib "builtin" [WGSL.VarExp "workgroup_id"]], + WGSL.Param + "local_id" + (WGSL.Prim (WGSL.Vec3 WGSL.UInt32)) + [WGSL.Attrib "builtin" [WGSL.VarExp "local_invocation_id"]] + ] + +builtinLockstepWidth :: KernelM WGSL.Ident +builtinLockstepWidth = mkGlobalIdent "lockstep_width" + +builtinBlockSize :: Int -> KernelM WGSL.Ident +builtinBlockSize 0 = mkGlobalIdent "block_size_x" +builtinBlockSize 1 = mkGlobalIdent "block_size_y" +builtinBlockSize 2 = mkGlobalIdent "block_size_z" +builtinBlockSize _ = error "invalid block size dimension" + +-- Main function for translating an ImpGPU kernel to a WebGPU kernel. +genKernel :: ImpGPU.Kernel -> WebGPUM (KernelName, [(Exp, PrimType)]) +genKernel kernel = do + let initial = + KernelState + { ksDecls = mempty, + ksInits = mempty, + ksBody = mempty, + ksNameReplacements = mempty, + ksOverrides = mempty, + ksBlockDims = mempty, + ksAtomicMem = mempty, + ksSharedMem = mempty, + ksScalars = mempty, + ksBindSlots = mempty + } + + ((), s, ()) <- runRWST gen (KernelR kernel) initial + + addCode $ docText $ WGSL.prettyDecls (ksDecls s) + addCode "\n\n" + + let name = nameToText $ ImpGPU.kernelName kernel + let blockDimNames = [n | (_, n, _) <- ksBlockDims s] + let attribs = + [ WGSL.Attrib "compute" [], + WGSL.Attrib "workgroup_size" (map WGSL.VarExp blockDimNames) + ] + let wgslFun = + WGSL.Function + { WGSL.funName = textToIdent name, + WGSL.funAttribs = attribs, + WGSL.funParams = entryParams, + WGSL.funOutput = [], + WGSL.funBody = WGSL.stmts (ksInits s ++ ksBody s) + } + addCode $ prettyText wgslFun + + let (offsets, _align, size) = + -- dummy layout with single i32 instead of empty structs + fromMaybe ([], 4, 4) (WGSL.structLayout (ksScalars s)) + let dynamicBlockDims = [(dim, n) | (dim, n, True) <- ksBlockDims s] + let (sharedMemOverrides, sharedMemExps) = unzip $ ksSharedMem s + let interface = + KernelInterface + { safety = SafetyNone, -- TODO + scalarsOffsets = offsets, + scalarsSize = size, + scalarsBindSlot = head (ksBindSlots s), + memBindSlots = tail (ksBindSlots s), + overrideNames = ksOverrides s, + dynamicBlockDims = dynamicBlockDims, + sharedMemoryOverrides = sharedMemOverrides, + gpuProgram = T.empty + } + State.modify $ \ws -> ws {wsKernels = wsKernels ws <> [(name, interface)]} + pure (nameFromText name, map (,IntType Int32) sharedMemExps) + where + gen = do + genConstAndBuiltinDecls + genScalarDecls + genMemoryDecls + + genDeviceFuns $ ImpGPU.kernelBody kernel + + -- FIXME: This is only required to work around a Chrome bug that otherwise + -- causes the shader to fail compilation if the kernel never accesses the + -- lockstep width. See `gpu_create_kernel` in `rts/c/backends/webgpu.h` + -- for more details. + lsw <- builtinLockstepWidth + addInitStmt $ + WGSL.Let "_dummy_lockstep_width" (WGSL.VarExp lsw) + + wgslBody <- genWGSLStm (ImpGPU.kernelBody kernel) + addBodyStmt wgslBody + + pure () + +calledInKernelOp :: ImpGPU.KernelOp -> S.Set Name +calledInKernelOp = const mempty + +lookupFunction :: Name -> WebGPUS -> Maybe (ImpGPU.Function ImpGPU.HostOp) +lookupFunction fname = lookup fname . unFunctions . wsFuns + +functionMayFail :: Name -> WebGPUS -> Bool +functionMayFail fname = S.member fname . 
wsFunsMayFail + +genFunParams :: [Param] -> KernelM [WGSL.Param] +genFunParams = + mapM $ \case + MemParam _ _ -> compilerLimitation "WebGPU backend cannot handle GPU functions with memory parameters." + ScalarParam name tp -> do + ident <- getIdent name + pure $ WGSL.Param ident (WGSL.Prim $ wgslPrimType tp) [] + +generateDeviceFun :: Name -> ImpGPU.Function ImpGPU.KernelOp -> KernelM () +generateDeviceFun fname device_func = do + when (any memParam $ functionInput device_func) $ + compilerLimitation "WebGPU backend cannot generate GPU functions that use arrays." + ws <- lift State.get + ks <- get + r <- ask + (body, _, _) <- lift $ runRWST (genWGSLStm (functionBody device_func)) r ks + + if functionMayFail fname ws + then compilerLimitation "WebGPU backend Cannot handle GPU functions that may fail." + else do + in_params <- genFunParams (functionInput device_func) + out_params <- genFunParams (functionOutput device_func) + let out_ptr_params = + map + ( \(WGSL.Param name (WGSL.Prim t) a) -> + WGSL.Param (name <> "_out") (WGSL.Pointer t WGSL.FunctionSpace Nothing) a + ) + out_params + let wgslFun = + WGSL.Function + { WGSL.funName = "futrts_" <> nameToText fname, + WGSL.funAttribs = mempty, + WGSL.funParams = in_params, + WGSL.funOutput = out_ptr_params, + WGSL.funBody = body + } + in prependDecl $ WGSL.FunDecl wgslFun + + lift $ State.modify $ \s -> + s + { wsDevFuns = S.insert fname $ wsDevFuns s + } + + genDeviceFuns $ functionBody device_func + where + memParam MemParam {} = True + memParam ScalarParam {} = False + +-- Ensure that this device function is available, but don't regenerate +-- it if it already exists. +ensureDeviceFun :: Name -> ImpGPU.Function ImpGPU.KernelOp -> KernelM () +ensureDeviceFun fname host_func = do + exists <- lift $ State.gets $ S.member fname . wsDevFuns + unless exists $ generateDeviceFun fname host_func + +genDeviceFuns :: ImpGPU.KernelCode -> KernelM () +genDeviceFuns code = do + let called = calledFuncs calledInKernelOp code + forM_ (S.toList called) $ \fname -> do + def <- lift $ State.gets $ lookupFunction fname + case def of + Just host_func -> do + let device_func = fmap toDevice host_func + ensureDeviceFun fname device_func + Nothing -> pure () + where + toDevice :: ImpGPU.HostOp -> ImpGPU.KernelOp + toDevice _ = compilerLimitation "WebGPU backend cannot handle GPU functions that contain parallelism." + +onKernel :: ImpGPU.Kernel -> WebGPUM HostOp +onKernel kernel = do + (name, extraArgExps) <- genKernel kernel + let numBlocks = ImpGPU.kernelNumBlocks kernel + let blockDim = ImpGPU.kernelBlockSize kernel + let extraArgs = [ValueKArg e t | (e, t) <- extraArgExps] + let scalarArgs = + [ ValueKArg (LeafExp n t) t + | ImpGPU.ScalarUse n t <- ImpGPU.kernelUses kernel + ] + let memArgs = [MemKArg n | ImpGPU.MemoryUse n <- ImpGPU.kernelUses kernel] + let args = extraArgs ++ scalarArgs ++ memArgs + + pure $ LaunchKernel SafetyNone name 0 args numBlocks blockDim + +onHostOp :: ImpGPU.HostOp -> WebGPUM HostOp +onHostOp (ImpGPU.CallKernel k) = onKernel k +onHostOp (ImpGPU.GetSize v key size_class) = do + addSize key size_class + pure $ GetSize v key +onHostOp (ImpGPU.CmpSizeLe v key size_class x) = do + addSize key size_class + pure $ CmpSizeLe v key x +onHostOp (ImpGPU.GetSizeMax v size_class) = + pure $ GetSizeMax v size_class + +-- | Generate WebGPU host and device code. 
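+--
+-- Host code stays imperative (over 'HostOp'), while each kernel body is
+-- rendered into the accumulated WGSL as a compute entry point of roughly this
+-- shape (kernel name illustrative; only the block dimensions the kernel
+-- actually has are emitted):
+--
+-- > @compute @workgroup_size(kernelname_block_size_x, kernelname_block_size_y, kernelname_block_size_z)
+-- > fn kernelname(@builtin(workgroup_id) workgroup_id: vec3<u32>,
+-- >               @builtin(local_invocation_id) local_id: vec3<u32>) { ... }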
+kernelsToWebGPU :: ImpGPU.Program -> Program +kernelsToWebGPU prog = + let ImpGPU.Definitions + types + (ImpGPU.Constants ps consts) + (ImpGPU.Functions funs) = prog + + initial_state = + WebGPUS + { wsCode = mempty, + wsSizes = mempty, + wsMacroDefs = mempty, + wsKernels = mempty, + wsNextBindSlot = 0, + wsDevFuns = mempty, + wsFuns = defFuns prog, + wsFunsMayFail = S.empty + } + + ((consts', funs'), translation) = + flip State.runState initial_state $ + (,) <$> traverse onHostOp consts <*> traverse (traverse (traverse onHostOp)) funs + + prog' = + Definitions types (Constants ps consts') (Functions funs') + + kernels = M.fromList $ map (first nameFromText) (wsKernels translation) + constants = wsMacroDefs translation + -- TODO: Compute functions using tuning params + params = M.map (,S.empty) $ wsSizes translation + failures = mempty + in Program + { webgpuProgram = wsCode translation, + webgpuPrelude = RTS.wgsl_prelude, + webgpuMacroDefs = constants, + webgpuKernels = kernels, + webgpuParams = params, + webgpuFailures = failures, + hostDefinitions = prog' + } + +-- | Compile the program to ImpCode with WebGPU kernels. +compileProg :: (MonadFreshNames m) => F.Prog F.GPUMem -> m (Warnings, Program) +compileProg prog = second kernelsToWebGPU <$> ImpGPU.compileProgOpenCL prog + +wgslInt8, wgslInt16, wgslInt64 :: WGSL.PrimType +wgslInt8 = WGSL.Int32 +wgslInt16 = WGSL.Int32 +wgslInt64 = WGSL.Vec2 WGSL.Int32 + +wgslPrimType :: PrimType -> WGSL.PrimType +wgslPrimType (IntType Int8) = wgslInt8 +wgslPrimType (IntType Int16) = wgslInt16 +wgslPrimType (IntType Int32) = WGSL.Int32 +wgslPrimType (IntType Int64) = wgslInt64 +wgslPrimType (FloatType Float16) = WGSL.Float16 +wgslPrimType (FloatType Float32) = WGSL.Float32 +wgslPrimType (FloatType Float64) = compilerLimitation "WebGPU backend does not support f64." 
+wgslPrimType Bool = WGSL.Bool +-- TODO: Make sure we do not ever codegen statements involving Unit variables +wgslPrimType Unit = WGSL.Float16 -- error "TODO: no unit in WGSL" + +wgslBufferType :: (PrimType, Bool, Signedness) -> Maybe WGSL.Exp -> WGSL.Typ +wgslBufferType (Bool, _, _) = WGSL.Array $ WGSL.Atomic wgslInt8 +wgslBufferType (IntType Int8, _, _) = WGSL.Array $ WGSL.Atomic wgslInt8 +wgslBufferType (IntType Int16, _, _) = WGSL.Array $ WGSL.Atomic wgslInt16 +wgslBufferType (IntType Int32, False, _) = WGSL.Array WGSL.Int32 +wgslBufferType (IntType Int32, True, _) = + WGSL.Array $ WGSL.Atomic WGSL.Int32 +wgslBufferType (FloatType Float32, True, _) = + WGSL.Array $ WGSL.Atomic WGSL.Int32 +wgslBufferType (FloatType Float16, True, _) = + WGSL.Array $ WGSL.Atomic WGSL.Int32 +wgslBufferType (t, _, _) = WGSL.Array $ wgslPrimType t + +wgslSharedBufferType :: + (PrimType, Bool, Signedness) -> + Maybe WGSL.Exp -> + WGSL.Typ +wgslSharedBufferType (Bool, _, _) = WGSL.Array WGSL.Bool +wgslSharedBufferType (IntType Int8, True, _) = WGSL.Array $ WGSL.Atomic WGSL.Int32 +wgslSharedBufferType (IntType Int16, True, _) = WGSL.Array $ WGSL.Atomic WGSL.Int32 +wgslSharedBufferType (IntType Int8, False, _) = WGSL.Array WGSL.Int32 +wgslSharedBufferType (IntType Int16, False, _) = WGSL.Array WGSL.Int32 +wgslSharedBufferType (IntType Int32, False, _) = WGSL.Array WGSL.Int32 +wgslSharedBufferType (IntType Int32, True, Signed) = + WGSL.Array $ WGSL.Atomic WGSL.Int32 +wgslSharedBufferType (IntType Int32, True, Unsigned) = + WGSL.Array $ WGSL.Atomic WGSL.UInt32 +wgslSharedBufferType (FloatType Float32, True, _) = + WGSL.Array $ WGSL.Atomic WGSL.Int32 +wgslSharedBufferType (FloatType Float16, True, _) = + WGSL.Array $ WGSL.Atomic WGSL.Int32 +wgslSharedBufferType (t, _, _) = WGSL.Array $ wgslPrimType t + +packedElemIndex :: PrimType -> WGSL.Exp -> WGSL.Exp +packedElemIndex Bool i = WGSL.BinOpExp "/" i (WGSL.IntExp 4) +packedElemIndex (IntType Int8) i = WGSL.BinOpExp "/" i (WGSL.IntExp 4) +packedElemIndex (IntType Int16) i = WGSL.BinOpExp "/" i (WGSL.IntExp 2) +packedElemIndex (IntType Int32) i = i +packedElemIndex (FloatType Float16) i = WGSL.BinOpExp "/" i (WGSL.IntExp 2) +packedElemIndex (FloatType Float32) i = i +packedElemIndex _ _ = error "CodeGen.ImpGen.WebGPU:packedElemIndex: Unsupported Type" + +packedElemOffset :: PrimType -> WGSL.Exp -> WGSL.Exp +packedElemOffset Bool i = WGSL.BinOpExp "%" i (WGSL.IntExp 4) +packedElemOffset (IntType Int8) i = WGSL.BinOpExp "%" i (WGSL.IntExp 4) +packedElemOffset (IntType Int16) i = WGSL.BinOpExp "%" i (WGSL.IntExp 2) +packedElemOffset (IntType Int32) _ = WGSL.IntExp 0 +packedElemOffset (FloatType Float16) i = WGSL.BinOpExp "%" i (WGSL.IntExp 2) +packedElemOffset (FloatType Float32) _ = WGSL.IntExp 0 +packedElemOffset _ _ = error "CodeGen.ImpGen.WebGPU:packedElemOffset: Unsupported Type" + +nativeAccessType :: PrimType -> Bool +nativeAccessType (IntType Int64) = True +nativeAccessType (IntType Int32) = True +nativeAccessType (FloatType Float32) = True +nativeAccessType (FloatType Float16) = True +nativeAccessType _ = False + +nativeAccessSpace :: Space -> Bool +nativeAccessSpace (ScalarSpace _ _) = True +nativeAccessSpace (Space "shared") = True +nativeAccessSpace _ = False + +genArrayAccess :: + Bool -> + PrimType -> + WGSL.Ident -> + WGSL.Exp -> + Bool -> + [WGSL.Exp] +genArrayAccess atomic t mem i packed = do + if not atomic + then [WGSL.UnOpExp "&" (WGSL.VarExp mem), i] + else + if packed + then + let packedIndex = packedElemIndex t i + packedOffset = packedElemOffset 
t i + in [WGSL.UnOpExp "&" (WGSL.IndexExp mem packedIndex), packedOffset] + else [WGSL.UnOpExp "&" (WGSL.IndexExp mem i)] + +genArrayFun :: WGSL.Ident -> PrimType -> Bool -> Bool -> WGSL.Ident +genArrayFun fun t atomic shared = + if atomic + then + let scope = if shared then "_shared" else "_global" + prefix = if atomic then "atomic_" else "" + in prefix <> fun <> "_" <> prettyText t <> scope + else fun <> "_" <> prettyText t + +genReadExp :: + PrimType -> + Space -> + VName -> + Count Elements (TExp Int64) -> + KernelM WGSL.Exp +genReadExp t s mem i = do + mem' <- getIdent mem + shared <- isShared mem' + atomic <- isAtomic mem' + i' <- indexExp i + + if (nativeAccessSpace s || nativeAccessType t || shared) && not atomic + then pure $ WGSL.IndexExp mem' i' + else + let access = genArrayAccess atomic t mem' i' (not shared) + fun = genArrayFun "read" t atomic shared + in pure $ WGSL.CallExp fun access + +genArrayRead :: + PrimType -> + Space -> + VName -> + VName -> + Count Elements (TExp Int64) -> + KernelM WGSL.Stmt +genArrayRead t s tgt mem i = do + tgt' <- getIdent tgt + read' <- genReadExp t s mem i + pure $ WGSL.Assign tgt' read' + +genArrayWrite :: + PrimType -> + Space -> + VName -> + Count Elements (TExp Int64) -> + KernelM WGSL.Exp -> + KernelM WGSL.Stmt +genArrayWrite t s mem i v = do + mem' <- getIdent mem + shared <- isShared mem' + atomic <- isAtomic mem' + i' <- indexExp i + v' <- v + + if (nativeAccessSpace s || nativeAccessType t || shared) && not atomic + then pure $ WGSL.AssignIndex mem' i' v' + else + let access = genArrayAccess atomic t mem' i' (not shared) ++ [v'] + fun = genArrayFun "write" t atomic shared + in pure $ WGSL.Call fun access + +genAtomicOp :: + WGSL.Ident -> + PrimType -> + VName -> + VName -> + Count Elements (TExp Int64) -> + Exp -> + KernelM WGSL.Stmt +genAtomicOp f t dest mem i e = do + mem' <- getIdent mem + shared <- isShared mem' + i' <- indexExp i + v' <- genWGSLExp e + + let fun = genArrayFun f t True shared + args = genArrayAccess True t mem' i' (not shared) ++ [v'] + in WGSL.Assign <$> getIdent dest <*> pure (WGSL.CallExp fun args) + +genCopy :: + PrimType -> + [Count Elements (TExp Int64)] -> + (VName, Space) -> + ( Count Elements (TExp Int64), + [Count Elements (TExp Int64)] + ) -> + (VName, Space) -> + ( Count Elements (TExp Int64), + [Count Elements (TExp Int64)] + ) -> + KernelM WGSL.Stmt +genCopy pt shape (dst, dst_space) (dst_offset, dst_strides) (src, src_space) (src_offset, src_strides) = do + shape' <- mapM (genWGSLExp . untyped . unCount) shape + body <- do + traceM $ "Generating copy for " <> T.unpack (prettyText dst) <> " <- " <> T.unpack (prettyText src) <> " dst space is " <> T.unpack (prettyText dst_space) + let dst_i = dst_offset + sum (zipWith (*) is' dst_strides) + src_i = src_offset + sum (zipWith (*) is' src_strides) + read' = genReadExp pt src_space src src_i + in genArrayWrite pt dst_space dst dst_i read' + + pure $ loops (zip iis shape') body + where + (zero, one) = (WGSL.VarExp "zero_i64", WGSL.VarExp "one_i64") + is = map (VName "i") [0 .. length shape - 1] + is' :: [Count Elements (TExp Int64)] + is' = map (elements . 
le64) is + iis = map nameToIdent is + + loops :: [(WGSL.Ident, WGSL.Exp)] -> WGSL.Stmt -> WGSL.Stmt + loops [] body = body + loops ((i, n) : ins) body = + WGSL.For + i + zero + (wgslCmpOp (CmpUlt Int64) (WGSL.VarExp i) n) + (WGSL.Assign i $ wgslBinOp (Add Int64 OverflowWrap) (WGSL.VarExp i) one) + (loops ins body) + +unsupported :: Code ImpGPU.KernelOp -> KernelM WGSL.Stmt +unsupported stmt = pure $ WGSL.Comment $ "Unsupported stmt: " <> prettyText stmt + +wgslProduct :: [SubExp] -> WGSL.Exp +wgslProduct [] = WGSL.IntExp 1 +wgslProduct [Constant (IntValue v)] = WGSL.IntExp $ valueIntegral v +wgslProduct ((Constant (IntValue v)) : vs) = + wgslBinOp (Mul Int32 OverflowWrap) (WGSL.IntExp $ valueIntegral v) (wgslProduct vs) +wgslProduct _ = error "wgslProduct: non-constant product" + +genWGSLStm :: Code ImpGPU.KernelOp -> KernelM WGSL.Stmt +genWGSLStm Skip = pure WGSL.Skip +genWGSLStm (s1 :>>: s2) = liftM2 WGSL.Seq (genWGSLStm s1) (genWGSLStm s2) +genWGSLStm (For iName bound body) = do + boundExp <- genWGSLExp bound + bodyStm <- genWGSLStm body + pure $ + WGSL.For + i + zero + (lt (WGSL.VarExp i) boundExp) + (WGSL.Assign i $ add (WGSL.VarExp i) one) + bodyStm + where + i = nameToIdent iName + boundIntType = case primExpType bound of + IntType t -> t + _ -> error "non-integer Exp for loop bound" + add = wgslBinOp $ Add boundIntType OverflowWrap + lt = wgslCmpOp $ CmpUlt boundIntType + (zero, one) = case boundIntType of + Int64 -> (WGSL.VarExp "zero_i64", WGSL.VarExp "one_i64") + _ -> (WGSL.IntExp 0, WGSL.IntExp 1) +genWGSLStm (While cond body) = + liftM2 + WGSL.While + (genWGSLExp $ untyped cond) + (genWGSLStm body) +genWGSLStm (DeclareMem name (Space "shared")) = do + let name' = nameToIdent name + moduleName <- mkGlobalIdent name' + sizeName <- mkGlobalIdent $ name' <> "_size" + + maybeElemPrimType <- findSingleMemoryType name + case maybeElemPrimType of + Just elemPrimType@(_, atomic, _) -> do + let bufType = + wgslSharedBufferType + elemPrimType + (Just $ WGSL.VarExp sizeName) + + addOverride sizeName (WGSL.Prim WGSL.Int32) (Just $ WGSL.IntExp 0) + addDecl $ WGSL.VarDecl [] WGSL.Workgroup moduleName bufType + when atomic $ addAtomicMem moduleName + addRename name' moduleName + pure $ WGSL.Comment $ "declare_shared: " <> name' + Nothing -> + pure $ WGSL.Comment $ "discard declare_shared: " <> name' +genWGSLStm (DeclareMem name (ScalarSpace vs pt)) = + pure $ + WGSL.DeclareVar (nameToIdent name) (WGSL.Array (wgslPrimType pt) (Just $ wgslProduct vs)) +genWGSLStm s@(DeclareMem _ _) = unsupported s +genWGSLStm (DeclareScalar name _ typ) = + pure $ + WGSL.DeclareVar (nameToIdent name) (WGSL.Prim $ wgslPrimType typ) +genWGSLStm s@(DeclareArray {}) = unsupported s +genWGSLStm s@(Allocate {}) = unsupported s +genWGSLStm s@(Free _ _) = pure $ WGSL.Comment $ "free: " <> prettyText s +genWGSLStm (Copy pt shape dst dst_lmad src src_lmad) = genCopy pt shape dst dst_lmad src src_lmad +genWGSLStm (Write mem i Bool s _ v) = genArrayWrite Bool s mem i (genWGSLExp v) +genWGSLStm (Write mem i (IntType Int8) s _ v) = genArrayWrite (IntType Int8) s mem i (genWGSLExp v) +genWGSLStm (Write mem i (IntType Int16) s _ v) = genArrayWrite (IntType Int16) s mem i (genWGSLExp v) +genWGSLStm (Write mem i t s _ v) = do + mem' <- getIdent mem + i' <- indexExp i + v' <- genWGSLExp v + atomic <- isAtomic mem' + if atomic + then genArrayWrite t s mem i (genWGSLExp v) + else pure $ WGSL.AssignIndex mem' i' v' +genWGSLStm (SetScalar name e) = + liftM2 WGSL.Assign (getIdent name) (genWGSLExp e) +genWGSLStm (Read tgt mem i Bool s 
_) = genArrayRead Bool s tgt mem i +genWGSLStm (Read tgt mem i (IntType Int8) s _) = genArrayRead (IntType Int8) s tgt mem i +genWGSLStm (Read tgt mem i (IntType Int16) s _) = genArrayRead (IntType Int16) s tgt mem i +genWGSLStm (Read tgt mem i t s _) = do + tgt' <- getIdent tgt + mem' <- getIdent mem + i' <- indexExp i + atomic <- isAtomic mem' + if atomic + then genArrayRead t s tgt mem i + else pure $ WGSL.Assign tgt' (WGSL.IndexExp mem' i') +genWGSLStm stm@(SetMem {}) = + compilerLimitation . docText $ + "WebGPU backend Cannot handle SetMem statement" + indent 2 (align (pretty stm)) + "in GPU kernel." +genWGSLStm (Call [dest] f args) = do + fun <- WGSL.CallExp . ("futrts_" <>) <$> getIdent f + let getArg (ExpArg e) = genWGSLExp e + getArg (MemArg n) = WGSL.VarExp <$> getIdent n + argExps <- mapM getArg args + WGSL.Assign <$> getIdent dest <*> pure (fun argExps) +genWGSLStm (Call dests f args) = do + fname <- ("futrts_" <>) <$> getIdent f + outIdents <- mapM getIdent dests + + let getArg (ExpArg e) = genWGSLExp e + getArg (MemArg n) = WGSL.VarExp <$> getIdent n + + argExps <- mapM getArg args + outExps <- mapM (pure . WGSL.UnOpExp "&" . WGSL.VarExp) outIdents + + pure $ WGSL.Call fname (argExps ++ outExps) +genWGSLStm (If cond cThen cElse) = + liftM3 + WGSL.If + (genWGSLExp $ untyped cond) + (genWGSLStm cThen) + (genWGSLStm cElse) +genWGSLStm s@(Assert {}) = pure $ WGSL.Comment $ "assert: " <> prettyText s +genWGSLStm (Meta (MetaComment c)) = + pure $ WGSL.Comment c +genWGSLStm (Meta (MetaProvenance (Provenance _ l))) = + pure $ WGSL.Comment $ locText l +genWGSLStm (DebugPrint _ _) = pure WGSL.Skip +genWGSLStm (TracePrint _) = pure WGSL.Skip +genWGSLStm (Op (ImpGPU.GetBlockId dest i)) = do + destId <- getIdent dest + pure $ + WGSL.Assign destId $ + WGSL.to_i32 (WGSL.IndexExp "workgroup_id" (WGSL.IntExp i)) +genWGSLStm (Op (ImpGPU.GetLocalId dest i)) = do + destId <- getIdent dest + pure $ + WGSL.Assign destId $ + WGSL.to_i32 (WGSL.IndexExp "local_id" (WGSL.IntExp i)) +genWGSLStm (Op (ImpGPU.GetLocalSize dest i)) = do + destId <- getIdent dest + WGSL.Assign destId . WGSL.VarExp <$> builtinBlockSize i +genWGSLStm (Op (ImpGPU.GetLockstepWidth dest)) = do + destId <- getIdent dest + WGSL.Assign destId . 
WGSL.VarExp <$> builtinLockstepWidth +genWGSLStm (Op (ImpGPU.Atomic _ (ImpGPU.AtomicAdd t dest mem i e))) = + genAtomicOp "add" (IntType t) dest mem i e +genWGSLStm (Op (ImpGPU.Atomic _ (ImpGPU.AtomicSMax t dest mem i e))) = + genAtomicOp "smax" (IntType t) dest mem i e +genWGSLStm (Op (ImpGPU.Atomic _ (ImpGPU.AtomicSMin t dest mem i e))) = + genAtomicOp "smin" (IntType t) dest mem i e +genWGSLStm (Op (ImpGPU.Atomic _ (ImpGPU.AtomicUMax t dest mem i e))) = + genAtomicOp "umax" (IntType t) dest mem i e +genWGSLStm (Op (ImpGPU.Atomic _ (ImpGPU.AtomicUMin t dest mem i e))) = + genAtomicOp "umin" (IntType t) dest mem i e +genWGSLStm (Op (ImpGPU.Atomic _ (ImpGPU.AtomicAnd t dest mem i e))) = + genAtomicOp "and" (IntType t) dest mem i e +genWGSLStm (Op (ImpGPU.Atomic _ (ImpGPU.AtomicOr t dest mem i e))) = + genAtomicOp "or" (IntType t) dest mem i e +genWGSLStm (Op (ImpGPU.Atomic _ (ImpGPU.AtomicXor t dest mem i e))) = + genAtomicOp "xor" (IntType t) dest mem i e +genWGSLStm (Op (ImpGPU.Atomic _ (ImpGPU.AtomicCmpXchg _ dest mem i cmp val))) = do + val' <- genWGSLExp val + cmp' <- genWGSLExp cmp + i' <- WGSL.IndexExp <$> getIdent mem <*> indexExp i + liftM2 WGSL.Assign (getIdent dest) (pure $ WGSL.FieldExp (WGSL.CallExp "atomicCompareExchangeWeak" [WGSL.UnOpExp "&" i', cmp', val']) "old_value") +genWGSLStm (Op (ImpGPU.Atomic _ (ImpGPU.AtomicXchg _ dest mem i e))) = do + idx <- WGSL.IndexExp <$> getIdent mem <*> indexExp i + val <- genWGSLExp e + let call = WGSL.CallExp "atomicExchange" [WGSL.UnOpExp "&" idx, val] + WGSL.Assign <$> getIdent dest <*> pure call +genWGSLStm (Op (ImpGPU.Atomic s (ImpGPU.AtomicWrite t mem i v))) = genArrayWrite t s mem i (genWGSLExp v) +genWGSLStm (Op (ImpGPU.Atomic _ (ImpGPU.AtomicFAdd t dest mem i e))) = genAtomicOp "fadd" (FloatType t) dest mem i e +genWGSLStm (Op (ImpGPU.Barrier ImpGPU.FenceLocal)) = + pure $ WGSL.Call "workgroupBarrier" [] +genWGSLStm (Op (ImpGPU.Barrier ImpGPU.FenceGlobal)) = + pure $ WGSL.Call "storageBarrier" [] +genWGSLStm s@(Op (ImpGPU.MemFence _)) = unsupported s +genWGSLStm (Op (ImpGPU.SharedAlloc name size)) = do + let name' = nameToIdent name + sizeName <- mkGlobalIdent $ name' <> "_size" + + maybeElemPrimType <- findSingleMemoryType name + case maybeElemPrimType of + Just (elemPrimType, _, _) -> do + let elemSize = primByteSize elemPrimType :: Int32 + let sizeExp = + zExt Int32 (untyped (unCount size)) + ~/~ ValueExp (IntValue $ Int32Value elemSize) + + addSharedMem sizeName sizeExp + pure $ WGSL.Comment $ "shared_alloc: " <> name' + Nothing -> + pure $ WGSL.Comment $ "discard shared_alloc: " <> name' +genWGSLStm (Op (ImpGPU.UniformRead tgt mem i _ _)) = do + tgt' <- getIdent tgt + mem' <- getIdent mem + i' <- indexExp i + pure $ + WGSL.Assign tgt' $ + WGSL.CallExp "workgroupUniformLoad" [WGSL.UnOpExp "&" $ WGSL.IndexExp mem' i'] +genWGSLStm (Op (ImpGPU.ErrorSync f)) = genWGSLStm $ Op (ImpGPU.Barrier f) + +call1 :: WGSL.Ident -> WGSL.Exp -> WGSL.Exp +call1 f a = WGSL.CallExp f [a] + +call2 :: WGSL.Ident -> WGSL.Exp -> WGSL.Exp -> WGSL.Exp +call2 f a b = WGSL.CallExp f [a, b] + +call2Suffix :: WGSL.Ident -> IntType -> WGSL.Exp -> WGSL.Exp -> WGSL.Exp +call2Suffix f t a b = WGSL.CallExp (f <> "_" <> prettyText t) [a, b] + +wgslBinOp :: BinOp -> WGSL.Exp -> WGSL.Exp -> WGSL.Exp +wgslBinOp (Add Int32 _) = WGSL.BinOpExp "+" +wgslBinOp (Add t _) = call2Suffix "add" t +wgslBinOp (FAdd _) = WGSL.BinOpExp "+" +wgslBinOp (Sub Int32 _) = WGSL.BinOpExp "-" +wgslBinOp (Sub t _) = call2Suffix "sub" t +wgslBinOp (FSub _) = WGSL.BinOpExp "-" +wgslBinOp 
(Mul Int32 _) = WGSL.BinOpExp "*" +wgslBinOp (Mul t _) = call2Suffix "mul" t +wgslBinOp (FMul _) = WGSL.BinOpExp "*" +-- Division is always safe in WGSL, so we can ignore the Safety parameter. +wgslBinOp (UDiv t _) = call2Suffix "udiv" t +wgslBinOp (UDivUp t _) = call2Suffix "udiv_up" t +wgslBinOp (SDiv t _) = call2Suffix "sdiv" t +wgslBinOp (SDivUp t _) = call2Suffix "sdiv_up" t +wgslBinOp (FDiv _) = WGSL.BinOpExp "/" +wgslBinOp (FMod _) = WGSL.BinOpExp "%" +wgslBinOp (UMod t _) = call2Suffix "umod" t +wgslBinOp (SMod t _) = call2Suffix "smod" t +wgslBinOp (SQuot Int8 _) = WGSL.BinOpExp "/" +wgslBinOp (SQuot Int16 _) = WGSL.BinOpExp "/" +wgslBinOp (SQuot Int32 _) = WGSL.BinOpExp "/" +wgslBinOp (SQuot Int64 _) = call2 "squot_i64" +wgslBinOp (SRem Int8 _) = WGSL.BinOpExp "%" +wgslBinOp (SRem Int16 _) = WGSL.BinOpExp "%" +wgslBinOp (SRem Int32 _) = WGSL.BinOpExp "%" +wgslBinOp (SRem Int64 _) = call2 "srem_i64" +wgslBinOp (SMin Int64) = call2 "smin_i64" +wgslBinOp (SMin _) = call2 "min" +wgslBinOp (UMin t) = call2Suffix "umin" t +wgslBinOp (FMin _) = call2 "min" +wgslBinOp (SMax Int64) = call2 "smax_i64" +wgslBinOp (SMax _) = call2 "max" +wgslBinOp (UMax t) = call2Suffix "umax" t +wgslBinOp (FMax _) = call2 "max" +wgslBinOp (Shl t) = call2Suffix "shl" t +wgslBinOp (LShr t) = call2Suffix "lshr" t +wgslBinOp (AShr t) = call2Suffix "ashr" t +wgslBinOp (And _) = WGSL.BinOpExp "&" +wgslBinOp (Or _) = WGSL.BinOpExp "|" +wgslBinOp (Xor _) = WGSL.BinOpExp "^" +wgslBinOp (Pow t) = call2Suffix "pow" t +wgslBinOp (FPow _) = call2 "pow" +wgslBinOp LogAnd = call2 "log_and" +wgslBinOp LogOr = call2 "log_or" + +-- Because we (in e.g. scalar8.wgsl) make sure to always sign-extend an i8 value +-- across its whole i32 representation, we can just use the normal comparison +-- operators for smaller integers. The same applies for i16. +wgslCmpOp :: CmpOp -> WGSL.Exp -> WGSL.Exp -> WGSL.Exp +wgslCmpOp (CmpEq (IntType Int64)) = call2 "eq_i64" +wgslCmpOp (CmpEq _) = WGSL.BinOpExp "==" +wgslCmpOp (CmpUlt Int64) = call2 "ult_i64" +wgslCmpOp (CmpUlt _) = call2 "ult_i32" +wgslCmpOp (CmpUle Int64) = call2 "ule_i64" +wgslCmpOp (CmpUle _) = call2 "ule_i32" +wgslCmpOp (CmpSlt Int64) = call2 "slt_i64" +wgslCmpOp (CmpSlt _) = WGSL.BinOpExp "<" +wgslCmpOp (CmpSle Int64) = call2 "sle_i64" +wgslCmpOp (CmpSle _) = WGSL.BinOpExp "<=" +wgslCmpOp (FCmpLt _) = WGSL.BinOpExp "<" +wgslCmpOp (FCmpLe _) = WGSL.BinOpExp "<=" +wgslCmpOp CmpLlt = call2 "llt" +wgslCmpOp CmpLle = call2 "lle" + +-- Similarly to CmpOps above, the defaults work for smaller integers already +-- given our representation. +wgslUnOp :: UnOp -> WGSL.Exp -> WGSL.Exp +wgslUnOp (Neg (FloatType _)) = WGSL.UnOpExp "-" +wgslUnOp (Neg (IntType t)) = call1 $ "neg_" <> prettyText t +wgslUnOp (Neg _) = WGSL.UnOpExp "!" 
+wgslUnOp (Complement _) = WGSL.UnOpExp "~" +wgslUnOp (Abs Int64) = call1 "abs_i64" +wgslUnOp (Abs _) = call1 "abs" +wgslUnOp (FAbs _) = call1 "abs" +wgslUnOp (SSignum Int64) = call1 "ssignum_i64" +wgslUnOp (SSignum _) = call1 "sign" +wgslUnOp (USignum Int64) = call1 "usignum_i64" +wgslUnOp (USignum _) = call1 "usignum_i32" +wgslUnOp (FSignum _) = call1 "sign" + +wgslConvOp :: ConvOp -> WGSL.Exp -> WGSL.Exp +wgslConvOp op a = WGSL.CallExp (fun op) [a] + where + fun (ZExt Int8 Int16) = "zext_i8_i16" + fun (SExt Int8 Int16) = "sext_i8_i16" + fun (ZExt Int8 Int32) = "zext_i8_i32" + fun (SExt Int8 Int32) = "sext_i8_i32" + fun (ZExt Int8 Int64) = "zext_i8_i64" + fun (SExt Int8 Int64) = "sext_i8_i64" + fun (ZExt Int16 Int32) = "zext_i16_i32" + fun (SExt Int16 Int32) = "sext_i16_i32" + fun (ZExt Int16 Int64) = "zext_i16_i64" + fun (SExt Int16 Int64) = "sext_i16_i64" + fun (ZExt Int32 Int64) = "zext_i32_i64" + fun (SExt Int32 Int64) = "sext_i32_i64" + fun (ZExt Int16 Int8) = "trunc_i16_i8" + fun (SExt Int16 Int8) = "trunc_i16_i8" + fun (ZExt Int32 Int8) = "trunc_i32_i8" + fun (SExt Int32 Int8) = "trunc_i32_i8" + fun (ZExt Int64 Int8) = "trunc_i64_i8" + fun (SExt Int64 Int8) = "trunc_i64_i8" + fun (ZExt Int32 Int16) = "trunc_i32_i16" + fun (SExt Int32 Int16) = "trunc_i32_i16" + fun (ZExt Int64 Int16) = "trunc_i64_i16" + fun (SExt Int64 Int16) = "trunc_i64_i16" + fun (ZExt Int64 Int32) = "trunc_i64_i32" + fun (SExt Int64 Int32) = "trunc_i64_i32" + fun (FPToUI Float16 Int8) = "f16_to_u8" + fun (FPToUI Float16 Int16) = "f16_to_u16" + fun (FPToUI Float16 Int32) = "f16_to_u32" + fun (FPToUI Float32 Int8) = "f32_to_u8" + fun (FPToUI Float32 Int16) = "f32_to_u16" + fun (FPToUI Float32 Int32) = "f32_to_u32" + fun (FPToSI Float16 Int8) = "f16_to_i8" + fun (FPToSI Float16 Int16) = "f16_to_i16" + fun (FPToSI Float16 Int32) = "i32" + fun (FPToSI Float32 Int8) = "f32_to_i8" + fun (FPToSI Float32 Int16) = "f32_to_i16" + fun (FPToSI Float32 Int32) = "i32" + fun (UIToFP Int8 Float16) = "u8_to_f16" + fun (UIToFP Int16 Float16) = "u16_to_f16" + fun (UIToFP Int32 Float16) = "u32_to_f16" + fun (UIToFP Int64 Float16) = "u64_to_f16" + fun (UIToFP Int8 Float32) = "u8_to_f32" + fun (UIToFP Int16 Float32) = "u16_to_f32" + fun (UIToFP Int32 Float32) = "u32_to_f32" + fun (UIToFP Int64 Float32) = "u64_to_f32" + fun (SIToFP Int64 Float16) = "i64_to_f16" + fun (SIToFP _ Float16) = "f16" + fun (SIToFP Int64 Float32) = "i64_to_f32" + fun (SIToFP _ Float32) = "f32" + fun (IToB Int64) = "i64_to_bool" + fun (IToB _) = "bool" + fun (BToI Int8) = "bool_to_i8" + fun (BToI Int16) = "bool_to_i16" + fun (BToI Int32) = "i32" + fun (BToI Int64) = "bool_to_i64" + fun o = "not_implemented(" <> prettyText o <> ")" + +intLiteral :: IntValue -> WGSL.Exp +intLiteral (Int8Value v) = + WGSL.CallExp "norm_i8" [WGSL.IntExp $ fromIntegral v] +intLiteral (Int16Value v) = + WGSL.CallExp "norm_i16" [WGSL.IntExp $ fromIntegral v] +intLiteral (Int64Value v) = WGSL.CallExp "i64" [low, high] + where + -- Helper function transforming a Int64 -> Int32 -> WGSL.Exp. + -- If we just do Int64 -> WGSL.Exp directly, low/high can end + -- up being larger than what fits in an i32. + toWGSLInt :: Int64 -> WGSL.Exp + toWGSLInt x = WGSL.IntExp $ fromIntegral (fromIntegral x :: Int32) + low = toWGSLInt $ v Bits..&. 0xffffffff + high = toWGSLInt $ (v `Bits.shift` (-32)) Bits..&. 
0xffffffff +intLiteral v = WGSL.IntExp (valueIntegral v) + +handleSpecialFloats :: T.Text -> Double -> WGSL.Exp +handleSpecialFloats s v + | isInfinite v, v > 0 = WGSL.CallExp (s <> "_inf") [] + | isInfinite v, v < 0 = WGSL.CallExp (s <> "_neg_inf") [] + | isNaN v = WGSL.CallExp (s <> "_nan") [] + | otherwise = WGSL.FloatExp v + +genFloatExp :: FloatValue -> WGSL.Exp +genFloatExp (Float16Value v) = handleSpecialFloats "f16" (convFloat v) +genFloatExp (Float32Value v) = handleSpecialFloats "f32" (convFloat v) +genFloatExp (Float64Value v) = handleSpecialFloats "f64" v + +genWGSLExp :: Exp -> KernelM WGSL.Exp +genWGSLExp (LeafExp name _) = WGSL.VarExp <$> getIdent name +genWGSLExp (ValueExp (IntValue v)) = pure $ intLiteral v +genWGSLExp (ValueExp (FloatValue v)) = pure $ genFloatExp v +genWGSLExp (ValueExp (BoolValue v)) = pure $ WGSL.BoolExp v +genWGSLExp (ValueExp UnitValue) = + error "should not attempt to generate unit expressions" +genWGSLExp (BinOpExp op e1 e2) = + liftM2 (wgslBinOp op) (genWGSLExp e1) (genWGSLExp e2) +genWGSLExp (CmpOpExp op e1 e2) = + liftM2 (wgslCmpOp op) (genWGSLExp e1) (genWGSLExp e2) +genWGSLExp (UnOpExp op e) = wgslUnOp op <$> genWGSLExp e +genWGSLExp (ConvOpExp op e) = wgslConvOp op <$> genWGSLExp e +genWGSLExp e = pure $ WGSL.StringExp $ " prettyText e <> ">" + +-- We support 64-bit arithmetic, but since WGSL does not have support for it, +-- we cannot use a 64-bit value as an index, so we have to truncate it to 32 +-- bits. +indexExp :: Count Elements (TExp Int64) -> KernelM WGSL.Exp +-- There are many occasions where we would end up extending to 64 bit and +-- immediately truncating, avoid that. +indexExp (Count (TPrimExp (ConvOpExp (SExt Int32 Int64) e))) = genWGSLExp e +indexExp c = (genWGSLExp . ConvOpExp (ZExt Int64 Int32) . untyped . unCount) c + +-- | Generate a struct declaration and corresponding uniform binding declaration +-- for all the scalar 'KernelUse's. Also generate a block of statements that +-- copies the struct fields into local variables so the kernel body can access +-- them unmodified. +genScalarDecls :: KernelM () +genScalarDecls = do + structName <- mkGlobalIdent "Scalars" + bufferName <- mkGlobalIdent "scalars" + uses <- asks (ImpGPU.kernelUses . 
krKernel) + + let scalarUses = [(nameToIdent name, typ) | ImpGPU.ScalarUse name typ <- uses] + scalars <- forM scalarUses $ + \(name, typ) -> do + let varPrimTyp = wgslPrimType typ + let fieldPrimTyp = case typ of + Bool -> WGSL.Int32 -- bool is not host-shareable + _ -> varPrimTyp + let wrapCopy e = case typ of + Bool -> WGSL.CallExp "bool" [e] + _ -> e + + addScalar fieldPrimTyp + addInitStmt $ WGSL.DeclareVar name (WGSL.Prim varPrimTyp) + addInitStmt $ + WGSL.Assign name (wrapCopy $ WGSL.FieldExp (WGSL.VarExp bufferName) name) + + pure (name, WGSL.Prim fieldPrimTyp) + + let scalarFields = case scalars of + [] -> [("_dummy_scalar", WGSL.Prim WGSL.Int32)] + sclrs -> sclrs + addDecl $ + WGSL.StructDecl $ + WGSL.Struct structName (map (uncurry WGSL.Field) scalarFields) + + slot <- assignBindSlot + let bufferAttribs = WGSL.bindingAttribs 0 slot + addDecl $ + WGSL.VarDecl bufferAttribs WGSL.Uniform bufferName (WGSL.Named structName) + +atomicOpArray :: ImpGPU.AtomicOp -> VName +atomicOpArray (ImpGPU.AtomicAdd _ _ n _ _) = n +atomicOpArray (ImpGPU.AtomicFAdd _ _ n _ _) = n +atomicOpArray (ImpGPU.AtomicSMax _ _ n _ _) = n +atomicOpArray (ImpGPU.AtomicSMin _ _ n _ _) = n +atomicOpArray (ImpGPU.AtomicUMax _ _ n _ _) = n +atomicOpArray (ImpGPU.AtomicUMin _ _ n _ _) = n +atomicOpArray (ImpGPU.AtomicAnd _ _ n _ _) = n +atomicOpArray (ImpGPU.AtomicOr _ _ n _ _) = n +atomicOpArray (ImpGPU.AtomicXor _ _ n _ _) = n +atomicOpArray (ImpGPU.AtomicCmpXchg _ _ n _ _ _) = n +atomicOpArray (ImpGPU.AtomicXchg _ _ n _ _) = n +atomicOpArray (ImpGPU.AtomicWrite _ n _ _) = n + +-- We declare all our variables as the signed type and re-interpret when +-- necessary for operations. +atomicOpType :: ImpGPU.AtomicOp -> (PrimType, Signedness) +atomicOpType (ImpGPU.AtomicAdd t _ _ _ _) = (IntType t, Signed) +atomicOpType (ImpGPU.AtomicFAdd t _ _ _ _) = (FloatType t, Signed) +atomicOpType (ImpGPU.AtomicSMax t _ _ _ _) = (IntType t, Signed) +atomicOpType (ImpGPU.AtomicSMin t _ _ _ _) = (IntType t, Signed) +atomicOpType (ImpGPU.AtomicUMax t _ _ _ _) = (IntType t, Signed) +atomicOpType (ImpGPU.AtomicUMin t _ _ _ _) = (IntType t, Signed) +atomicOpType (ImpGPU.AtomicAnd t _ _ _ _) = (IntType t, Signed) +atomicOpType (ImpGPU.AtomicOr t _ _ _ _) = (IntType t, Signed) +atomicOpType (ImpGPU.AtomicXor t _ _ _ _) = (IntType t, Signed) +atomicOpType (ImpGPU.AtomicCmpXchg t _ _ _ _ _) = (t, Signed) +atomicOpType (ImpGPU.AtomicXchg t _ _ _ _) = (t, Signed) +atomicOpType (ImpGPU.AtomicWrite t _ _ _) = (t, Signed) + +-- | Internally, memory buffers are untyped but WGSL requires us to annotate the +-- binding with a type. Search the kernel body for any reads and writes to the +-- given buffer and return all types it is accessed at. +-- The bool indicates an atomic type. The Signedness is only relevant for atomic +-- types as described for `atomicOpType`. +findMemoryTypes :: VName -> KernelM [(PrimType, Bool, Signedness)] +findMemoryTypes name = S.elems . find <$> asks (ImpGPU.kernelBody . 
krKernel) + where + find (ImpGPU.Write n _ t _ _ _) | n == name = S.singleton (t, False, Signed) + find (ImpGPU.Read _ n _ t _ _) | n == name = S.singleton (t, False, Signed) + find (ImpGPU.Copy t _ (n, _) _ _ _) | n == name = S.singleton (t, False, Signed) + find (ImpGPU.Copy t _ _ _ (n, _) _) | n == name = S.singleton (t, False, Signed) + find (Op (ImpGPU.Atomic _ op)) + | atomicOpArray op == name = + let (t, sgn) = atomicOpType op + in S.singleton (t, True, sgn) + find (s1 :>>: s2) = find s1 <> find s2 + find (For _ _ body) = find body + find (While _ body) = find body + find (If _ s1 s2) = find s1 <> find s2 + find _ = S.empty + +findSingleMemoryType :: VName -> KernelM (Maybe (PrimType, Bool, Signedness)) +findSingleMemoryType name = do + types <- findMemoryTypes name + let prims = nubOrd $ map (\(t, _, _) -> t) types + case prims of + [] -> pure Nothing + [prim] -> do + -- Only used at one primitive type. If it is an integer <=32 bit and there + -- are atomic accesses, make it the appropriate atomic type. Otherwise, + -- make sure there is only one type combination total. + let atomic = L.find (\(_, a, _) -> a) types + case atomic of + Just (_, _, sgn) + | canBeAtomic prim -> + if all (\(_, _, s) -> s == sgn) types + then pure $ Just (prim, True, sgn) + else error "Atomic type used at multiple signednesses" + Just (t, _, _) -> + error $ "Atomics not supported for values of type " <> show t + Nothing -> pure $ Just (prim, False, Signed) + _tooMany -> error "Buffer used at multiple types" + where + canBeAtomic (IntType Int64) = False + canBeAtomic (IntType _) = True + canBeAtomic (FloatType _) = True + canBeAtomic Bool = True + canBeAtomic _ = False + +-- | Generate binding declarations for memory buffers used by kernel. Produces +-- additional name replacements because it makes the binding names unique. +-- +-- We can't use the same trick as for e.g. scalars where we make a local copy to +-- avoid the name replacements because WGSL does not allow function-local +-- variables in the 'storage' address space. +genMemoryDecls :: KernelM () +genMemoryDecls = do + uses <- asks (ImpGPU.kernelUses . krKernel) + memUses <- catMaybes <$> sequence [withType n | ImpGPU.MemoryUse n <- uses] + mapM_ moduleDecl memUses + mapM_ rename memUses + where + withType name = do + typ <- findSingleMemoryType name + pure $ (nameToIdent name,) <$> typ + moduleDecl (name, typ) = do + ident <- mkGlobalIdent name + slot <- assignBindSlot + let (_, atomic, _) = typ + when atomic $ addAtomicMem ident + addDecl $ + WGSL.VarDecl + (WGSL.bindingAttribs 0 slot) + (WGSL.Storage WGSL.ReadWrite) + ident + (wgslBufferType typ Nothing) + rename (name, _) = mkGlobalIdent name >>= addRename name + +-- | Generate `override` declarations for kernel 'ConstUse's and +-- backend-provided values (like block size and lockstep width). +genConstAndBuiltinDecls :: KernelM () +genConstAndBuiltinDecls = do + kernel <- asks krKernel + + -- Start off with handling the block size parameters. + let blockDimExps = zip [0 ..] $ ImpGPU.kernelBlockSize kernel + blockDimNames <- mapM (builtinBlockSize . fst) blockDimExps + let blockDims = zip blockDimExps blockDimNames + let constBlockDims = [(i, n, e) | ((i, Right e), n) <- blockDims] + let dynBlockDims = [(i, n, e) | ((i, Left e), n) <- blockDims] + + forM_ blockDimNames $ + \n -> addOverride n (WGSL.Prim WGSL.Int32) zeroInit + + -- KernelConstExp block dims get generated into the general macro/override + -- machinery. 
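+  -- Concretely (sketch, identifier invented): every block dimension is
+  -- declared in the shader as an override like
+  --
+  --   override kernelname_block_size_x: i32 = 0;
+  --
+  -- and for the constant dimensions the evaluated 'KernelConstExp' is what
+  -- the host supplies as the value of that override when the pipeline is
+  -- created.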
+ lift $ + mapM_ (uncurry addMacroDef) [(nameFromText n, e) | (_, n, e) <- constBlockDims] + + mapM_ (\(i, n, _e) -> addBlockDim i n True) dynBlockDims + mapM_ (\(i, n, _e) -> addBlockDim i n False) constBlockDims + + -- Next we generate builtin override declarations. + lsWidth <- builtinLockstepWidth + addOverride lsWidth (WGSL.Prim WGSL.Int32) zeroInit + lift $ addMacroDef (nameFromText lsWidth) $ ValueExp (IntValue (Int32Value 1)) + + -- And lastly we handle ConstUses. + let consts = [(n, e) | ImpGPU.ConstUse n e <- ImpGPU.kernelUses kernel] + + let mkLo e = + untyped $ + (TPrimExp e :: TPrimExp Int64 KernelConst) .&. 0x00000000ffffffff + let mkHi e = + untyped $ + (TPrimExp e :: TPrimExp Int64 KernelConst) .>>. 32 + + let mkConst (name, e) = do + let n = nameToIdent name + lo <- mkGlobalIdent (n <> "_lo") + hi <- mkGlobalIdent (n <> "_hi") + case primExpType e of + IntType Int64 -> do + addOverride lo (WGSL.Prim WGSL.Int32) zeroInit + addOverride hi (WGSL.Prim WGSL.Int32) zeroInit + lift $ addMacroDef (nameFromText lo) (mkLo e) + lift $ addMacroDef (nameFromText hi) (mkHi e) + addInitStmt $ + WGSL.Seq + (WGSL.DeclareVar n (WGSL.Prim wgslInt64)) + ( WGSL.Assign n $ + WGSL.CallExp "i64" [WGSL.VarExp lo, WGSL.VarExp hi] + ) + _ -> do + addOverride lo (WGSL.Prim WGSL.Int32) zeroInit + lift $ addMacroDef (nameFromText lo) e + addInitStmt $ + WGSL.Seq + (WGSL.DeclareVar n (WGSL.Prim wgslInt64)) + ( WGSL.Assign n $ + WGSL.CallExp "i64" [WGSL.VarExp lo, WGSL.IntExp 0] + ) + + mapM_ mkConst consts + where + zeroInit = Just $ WGSL.IntExp 0 + +nameToIdent :: VName -> WGSL.Ident +nameToIdent = zEncodeText . prettyText + +textToIdent :: T.Text -> WGSL.Ident +textToIdent = zEncodeText diff --git a/src/Futhark/CodeGen/RTS/C.hs b/src/Futhark/CodeGen/RTS/C.hs index 9981108605..0c7c6c3a2e 100644 --- a/src/Futhark/CodeGen/RTS/C.hs +++ b/src/Futhark/CodeGen/RTS/C.hs @@ -29,6 +29,7 @@ module Futhark.CodeGen.RTS.C backendsHipH, backendsCH, backendsMulticoreH, + backendsWebGPUH, ) where @@ -176,6 +177,11 @@ backendsMulticoreH :: T.Text backendsMulticoreH = $(embedStringFile "rts/c/backends/multicore.h") {-# NOINLINE backendsMulticoreH #-} +-- | @rts/c/backends/webgpu.h@ +backendsWebGPUH :: T.Text +backendsWebGPUH = $(embedStringFile "rts/c/backends/webgpu.h") +{-# NOINLINE backendsWebGPUH #-} + -- | @rts/c/copy.h@ copyH :: T.Text copyH = $(embedStringFile "rts/c/copy.h") diff --git a/src/Futhark/CodeGen/RTS/WGSL.hs b/src/Futhark/CodeGen/RTS/WGSL.hs new file mode 100644 index 0000000000..24692249e3 --- /dev/null +++ b/src/Futhark/CodeGen/RTS/WGSL.hs @@ -0,0 +1,96 @@ +{-# LANGUAGE TemplateHaskell #-} + +-- | Code snippets used by the WebGPU backend as part of WGSL shaders. 
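+--
+-- The @scalar8@ and @scalar16@ snippets emulate 8- and 16-bit integers, which
+-- WGSL lacks, on top of @i32@, keeping values sign-extended across the whole
+-- word. One possible shape of such a helper (the actual definitions live in
+-- @rts/wgsl/scalar8.wgsl@) is
+--
+-- > fn norm_i8(x: i32) -> i32 { return (x << 24u) >> 24u; }
+--
+-- which is what lets the code generator use plain @i32@ comparisons on these
+-- values.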
+module Futhark.CodeGen.RTS.WGSL + ( scalar, + scalar8, + scalar16, + scalar32, + scalar64, + atomics, + wgsl_prelude, + lmad_copy, + map_transpose, + map_transpose_low_height, + map_transpose_low_width, + map_transpose_small, + map_transpose_large, + ) +where + +import Data.FileEmbed +import Data.Text qualified as T + +-- | @rts/wgsl/scalar.wgsl@ +scalar :: T.Text +scalar = $(embedStringFile "rts/wgsl/scalar.wgsl") +{-# NOINLINE scalar #-} + +-- | @rts/wgsl/scalar8.wgsl@ +scalar8 :: T.Text +scalar8 = $(embedStringFile "rts/wgsl/scalar8.wgsl") +{-# NOINLINE scalar8 #-} + +-- | @rts/wgsl/scalar16.wgsl@ +scalar16 :: T.Text +scalar16 = $(embedStringFile "rts/wgsl/scalar16.wgsl") +{-# NOINLINE scalar16 #-} + +-- | @rts/wgsl/scalar32.wgsl@ +scalar32 :: T.Text +scalar32 = $(embedStringFile "rts/wgsl/scalar32.wgsl") +{-# NOINLINE scalar32 #-} + +-- | @rts/wgsl/scalar64.wgsl@ +scalar64 :: T.Text +scalar64 = $(embedStringFile "rts/wgsl/scalar64.wgsl") +{-# NOINLINE scalar64 #-} + +-- | @rts/wgsl/atomics.wgsl@ +atomics :: T.Text +atomics = $(embedStringFile "rts/wgsl/atomics.wgsl") +{-# NOINLINE atomics #-} + +wgsl_prelude :: T.Text +wgsl_prelude = + -- Put scalar32 in front of the other integer types since they are all + -- internally represented using i32. + mconcat + [ "enable f16;\n", + scalar, + scalar32, + scalar8, + scalar16, + scalar64, + atomics + ] + +-- | @rts/wgsl/lmad_copy.wgsl@ +lmad_copy :: T.Text +lmad_copy = $(embedStringFile "rts/wgsl/lmad_copy.wgsl") +{-# NOINLINE lmad_copy #-} + +-- | @rts/wgsl/map_transpose.wgsl@ +map_transpose :: T.Text +map_transpose = $(embedStringFile "rts/wgsl/map_transpose.wgsl") +{-# NOINLINE map_transpose #-} + +-- | @rts/wgsl/map_transpose_low_height.wgsl@ +map_transpose_low_height :: T.Text +map_transpose_low_height = $(embedStringFile "rts/wgsl/map_transpose_low_height.wgsl") +{-# NOINLINE map_transpose_low_height #-} + +-- | @rts/wgsl/map_transpose_low_width.wgsl@ +map_transpose_low_width :: T.Text +map_transpose_low_width = $(embedStringFile "rts/wgsl/map_transpose_low_width.wgsl") +{-# NOINLINE map_transpose_low_width #-} + +-- | @rts/wgsl/map_transpose_small.wgsl@ +map_transpose_small :: T.Text +map_transpose_small = $(embedStringFile "rts/wgsl/map_transpose_small.wgsl") +{-# NOINLINE map_transpose_small #-} + +-- | @rts/wgsl/map_transpose_large.wgsl@ +map_transpose_large :: T.Text +map_transpose_large = $(embedStringFile "rts/wgsl/map_transpose_large.wgsl") +{-# NOINLINE map_transpose_large #-} diff --git a/src/Futhark/CodeGen/RTS/WebGPU.hs b/src/Futhark/CodeGen/RTS/WebGPU.hs new file mode 100644 index 0000000000..1d2bee6152 --- /dev/null +++ b/src/Futhark/CodeGen/RTS/WebGPU.hs @@ -0,0 +1,33 @@ +{-# LANGUAGE TemplateHaskell #-} + +-- | Code snippets used by the WebGPU backend. 
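The ordering comment in wgsl_prelude above reflects that the 8-, 16- and 64-bit integer operations are built on values carried in i32 words, so scalar32 has to come first. As a hedged illustration of what such an emulation involves (the .wgsl snippets themselves are not shown in this diff), an 8-bit add over i32-carried operands behaves roughly like this, sketched in Python:

    def wrap_i8(x):
        """Renormalise a value carried in a wider word to signed 8-bit range."""
        x &= 0xFF
        return x - 0x100 if x >= 0x80 else x

    def add_i8(a, b):
        # Operands live in 32-bit words; the sum must be wrapped so later
        # comparisons and further arithmetic behave like genuine i8.
        return wrap_i8(a + b)

    assert add_i8(120, 10) == -126    # 8-bit overflow wraps around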
+module Futhark.CodeGen.RTS.WebGPU + ( serverWsJs, + utilJs, + valuesJs, + wrappersJs, + ) +where + +import Data.FileEmbed +import Data.Text qualified as T + +-- | @rts/webgpu/server_ws.js@ +serverWsJs :: T.Text +serverWsJs = $(embedStringFile "rts/webgpu/server_ws.js") +{-# NOINLINE serverWsJs #-} + +-- | @rts/webgpu/util.js@ +utilJs :: T.Text +utilJs = $(embedStringFile "rts/webgpu/util.js") +{-# NOINLINE utilJs #-} + +-- | @rts/webgpu/values.js@ +valuesJs :: T.Text +valuesJs = $(embedStringFile "rts/webgpu/values.js") +{-# NOINLINE valuesJs #-} + +-- | @rts/webgpu/wrappers.js@ +wrappersJs :: T.Text +wrappersJs = $(embedStringFile "rts/webgpu/wrappers.js") +{-# NOINLINE wrappersJs #-} diff --git a/src/Futhark/Test/WebGPUTest.hs b/src/Futhark/Test/WebGPUTest.hs new file mode 100644 index 0000000000..bda2ca8d14 --- /dev/null +++ b/src/Futhark/Test/WebGPUTest.hs @@ -0,0 +1,177 @@ +module Futhark.Test.WebGPUTest + ( generateTests, + ) +where + +import Control.Monad.IO.Class (MonadIO, liftIO) +import Data.List (foldl') +import Data.Map qualified as M +import Data.Maybe (mapMaybe) +import Data.Text qualified as T +import Futhark.CodeGen.ImpCode.WebGPU +import Futhark.CodeGen.ImpGen.WebGPU (compileProg) +import Futhark.IR.GPUMem qualified as F +import Futhark.MonadFreshNames +import Futhark.Test.Spec +import Futhark.Test.Values qualified as V +import Futhark.Util.Pretty +import Futhark.Util + +generateTests :: + (MonadFreshNames m, MonadIO m) => + FilePath -> + F.Prog F.GPUMem -> + m T.Text +generateTests path prog = do + compiled <- snd <$> compileProg prog + spec <- liftIO $ testSpecFromProgramOrDie (path <> ".fut") + let tests = testCasesLiteral spec + let info = kernelInfoLiteral compiled + let shader = shaderLiteral compiled + pure (tests <> "\n\n" <> info <> "\n\n" <> shader) + +shaderLiteral :: Program -> T.Text +shaderLiteral prog = + "window.shader = `\n" + <> webgpuPrelude prog + <> "\n" + <> webgpuProgram prog + <> "\n`;" + +-- window.kernels = [ +-- { name: 'some_vname_5568', +-- overrides: ['override', 'declarations'], +-- scalarsBindSlot: 0, +-- bindSlots: [1, 2, 3], +-- }, +-- ]; +kernelInfoLiteral :: Program -> T.Text +kernelInfoLiteral prog = "window.kernels = " <> docText fmtInfos <> ";" + where + infos = M.toList $ webgpuKernels prog + fmtInfos = "[" indent 2 (commastack $ map fmtInfo infos) "]" + fmtInfo (name, ki) = + "{" + indent + 2 + ( "name: '" + <> pretty (zEncodeText (nameToText name)) + <> "'," + "overrides: [" + <> commasep + (map (\o -> "'" <> pretty o <> "'") (overrideNames ki)) + <> "]," + "scalarsBindSlot: " + <> pretty (scalarsBindSlot ki) + <> "," + "bindSlots: " + <> pretty (memBindSlots ki) + <> "," + ) + "}" + +-- window.tests = [ +-- { entry: 'someName', +-- runs: [ +-- { +-- inputTypes: ['i32'], +-- input: [[0, 1, 2, 3]], +-- expected: [[0, 2, 4, 6]], +-- }, +-- ], +-- }, +-- ]; +testCasesLiteral :: ProgramTest -> T.Text +testCasesLiteral (ProgramTest _ _ (RunCases ios _ _)) = + let specs = map ((<> ",\n") . prettyText . 
mkTestSpec) ios + in "window.tests = [\n" <> foldl' (<>) "" specs <> "];" +testCasesLiteral t = "// Unsupported test: " <> testDescription t + +data JsTestSpec = JsTestSpec + { _jsEntryPoint :: T.Text, + _jsRuns :: [JsTestRun] + } + +data JsTestRun = JsTestRun + { _jsInputTypes :: [V.PrimType], + _jsInput :: [V.Value], + _jsExpectedTypes :: [V.PrimType], + _jsExpected :: [V.Value] + } + +mkTestSpec :: InputOutputs -> JsTestSpec +mkTestSpec (InputOutputs entry runs) = JsTestSpec entry (mapMaybe mkRun runs) + +mkRun :: TestRun -> Maybe JsTestRun +mkRun + ( TestRun + _ + (Values vals) + (Succeeds (Just (SuccessValues (Values expected)))) + _ + _ + ) = + let inputTyps = map V.valueElemType vals + expectedTyps = map V.valueElemType expected + in Just $ JsTestRun inputTyps vals expectedTyps expected +mkRun _ = Nothing + +instance Pretty JsTestRun where + pretty (JsTestRun inputTyps input expectedTyps expected) = + "{" + indent + 2 + ( "inputTypes: [" + <> commasep + ( map + (\t -> "'" <> pretty (V.primTypeText t) <> "'") + inputTyps + ) + <> "]," + "input: " + <> fmt inputTyps input + <> "," + "expectedTypes: [" + <> commasep + ( map + (\t -> "'" <> pretty (V.primTypeText t) <> "'") + expectedTyps + ) + <> "]," + "expected: " + <> fmt expectedTyps expected + <> "," + ) + "}" + where + fmtVal V.I64 v = pretty v <> "n" + fmtVal V.U64 v = pretty v <> "n" + fmtVal _ v = pretty v + fmtArrRaw typ vs = "[" <> commasep (map (fmtVal typ) vs) <> "]" + -- Hacky way to avoid the 'i32', 'i64' etc. suffixes as they are not valid + -- JS. + fixAnnots typ d = pretty $ T.replace (V.primTypeText typ) "" (docText d) + fixSpecials d = + pretty $ + T.replace ".nan" "NaN" $ + T.replace ".inf" "Infinity" $ + docText d + fmtArray typ vs = + fixSpecials $ + fixAnnots typ $ + fmtArrRaw typ (V.valueElems vs) + fmt typs vss = "[" <> commasep (zipWith fmtArray typs vss) <> "]" + +instance Pretty JsTestSpec where + pretty (JsTestSpec entry runs) = + "{" + indent + 2 + ( "entry: '" + <> pretty entry + <> "'," + "runs: [" + indent 2 (commastack $ map pretty runs) + "]," + ) + "}" diff --git a/tests/.gitignore b/tests/.gitignore index 2c367988df..980e22f26c 100644 --- a/tests/.gitignore +++ b/tests/.gitignore @@ -7,6 +7,7 @@ /**/*.expected /**/*.actual /**/*.wasm +/**/*.js /**/*.json /**/*.ispc /**/*.cache diff --git a/tests/primitive/naninf16.fut b/tests/primitive/naninf16.fut new file mode 100644 index 0000000000..fae8815541 --- /dev/null +++ b/tests/primitive/naninf16.fut @@ -0,0 +1,68 @@ +-- NaN and inf must work. 
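The Pretty instances above serialise test data as JavaScript literals, which needs a few adjustments: 64-bit integers get a trailing n so they become BigInt literals, the i32/f64-style suffixes of Futhark value syntax are stripped, and .nan/.inf are rewritten to NaN/Infinity. The same value-level formatting, sketched in Python (not the generator itself; js_literal is an invented name):

    import math

    def js_literal(value, typ):
        """Render one scalar as it should appear inside window.tests."""
        if typ in ("i64", "u64"):
            return f"{value}n"                    # BigInt literal
        if typ in ("f16", "f32", "f64"):
            if math.isnan(value):
                return "NaN"
            if math.isinf(value):
                return "Infinity" if value > 0 else "-Infinity"
            return repr(float(value))
        if typ == "bool":
            return "true" if value else "false"
        return str(value)

    assert js_literal(2 ** 40, "i64") == "1099511627776n"
    assert js_literal(float("nan"), "f32") == "NaN"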
+ +-- == +-- entry: eqNaN +-- input { [2f16, f16.nan, f16.inf, -f16.inf] } +-- output { [false, false, false, false] } + +-- == +-- entry: ltNaN +-- input { [2f16, f16.nan, f16.inf, -f16.inf] } +-- output { [false, false, false, false] } + +-- == +-- entry: lteNaN +-- input { [2f16, f16.nan, f16.inf, -f16.inf] } +-- output { [false, false, false, false] } + +-- == +-- entry: ltInf +-- input { [2f16, f16.nan, f16.inf, -f16.inf] } +-- output { [true, false, false, true] } + +-- == +-- entry: lteInf +-- input { [2f16, f16.nan, f16.inf, -f16.inf] } +-- output { [true, false, true, true] } + +-- == +-- entry: diffInf +-- input { [2f16, f16.nan, f16.inf, -f16.inf] } +-- output { [true, false, false, false] } + +-- == +-- entry: sumNaN +-- input { [2f16, f16.nan, f16.inf, -f16.inf] } +-- output { [true, true, true, true] } + +-- == +-- entry: sumInf +-- input { [2f16, f16.nan, f16.inf, -f16.inf] } +-- output { [true, false, true, false] } + +-- == +-- entry: log2 +-- input { [2f16, f16.nan, f16.inf, -f16.inf] } +-- output { [false, true, false, true] } + +-- == +-- entry: log10 +-- input { [10f16, f16.nan, f16.inf, -f16.inf] } +-- output { [false, true, false, true] } + +-- == +-- entry: log1p +-- input { [-2f16, -1f16, 2f16, f16.nan, f16.inf, -f16.inf] } +-- output { [true, false, false, true, false, true] } + +entry eqNaN = map (\x -> x == f16.nan) +entry ltNaN = map (\x -> x < f16.nan) +entry lteNaN = map (\x -> x <= f16.nan) +entry ltInf = map (\x -> x < f16.inf) +entry lteInf = map (\x -> x <= f16.inf) +entry diffInf = map (\x -> x - f16.inf < x + f16.inf) +entry sumNaN = map (\x -> f16.isnan (x + f16.nan)) +entry sumInf = map (\x -> f16.isinf (x + f16.inf)) +entry log2 = map (\x -> f16.isnan (f16.log2 (x))) +entry log10 = map (\x -> f16.isnan (f16.log10 (x))) +entry log1p = map (\x -> f16.isnan (f16.log1p (x))) diff --git a/tools/browser_test.py b/tools/browser_test.py new file mode 100755 index 0000000000..1f483d37df --- /dev/null +++ b/tools/browser_test.py @@ -0,0 +1,330 @@ +#!/usr/bin/env python3 + +import argparse +import base64 +import json +import shlex +import subprocess +import sys +import os +from io import BytesIO + +import asyncio +import aiohttp +from aiohttp import web +import numpy as np + +from selenium import webdriver + +parser = argparse.ArgumentParser() +parser.add_argument( + "program", help="the program to wrap (without .js or other file extension)" +) +parser.add_argument( + "--no-server-proxy", + help=( + "only act as HTTP server, do not proxy Futhark server protocol\n" + "(implies --no-browser)" + ), + action="store_true", +) +parser.add_argument( + "--no-browser", + help=( + "do not start a browser, instead wait for one to connect.\n" + "Can also be set via NO_BROWSER=1 env variable." + ), + action="store_true", +) +parser.add_argument( + "--show-browser", + help=( + "disable headless mode for browser.\n" + "Can also be set via HEADLESS=0 env variable." + ), + action="store_true", +) +parser.add_argument( + "--web-driver", + help=( + "URL of a remote WebDriver to connec to.\n" + "Can also be set via WEB_DRIVER_URL env variable." + ), +) +parser.add_argument( + "--log", + help=( + "log file for debug output.\n" + "Can also be set via LOG_FILE env variable." 
+ ), +) +args = parser.parse_args() + +program_name = args.program +if program_name.startswith("./"): + program_name = program_name[2:] +script_name = program_name + ".js" +wasm_name = program_name + ".wasm" +wasm_map_name = program_name + ".wasm.map" +source_name = program_name + ".c" + +log_path = os.environ.get("LOG_FILE") +if log_path is None: + log_path = args.log + +log_file = None +if log_path is not None: + log_file = open(log_path, "w") + +remote_driver_url = os.environ.get("WEB_DRIVER_URL") +if remote_driver_url is None: + remote_driver_url = args.web_driver + +no_browser_env = os.environ.get("NO_BROWSER") +no_browser = None +if no_browser_env == "0": + no_browser = False +elif no_browser_env == "1": + no_browser = True +elif no_browser is None: + no_browser = args.no_browser + +headless_env = os.environ.get("HEADLESS") +headless = None +if headless_env == "0": + headless = False +elif headless_env == "1": + headless = True +elif headless is None: + headless = not args.show_browser + + +def eprint(*args, **kwargs): + print(*args, file=sys.stderr, flush=True, **kwargs) + if log_file is not None: + print(*args, file=log_file, flush=True, **kwargs) + + +index_page = ( + f"\n" + f"\n" + f"\n" + f" futhark test\n" + f' \n' + f"\n" + f"\n" + f"

futhark test

\n" + f"\n" + f"" +) + +default_headers = { + # This makes the site an "isolated context", most + # notably improving timing measurement resolution + # from 100 microseconds to 5 microseconds. + "Cross-Origin-Opener-Policy": "same-origin", + "Cross-Origin-Embedder-Policy": "require-corp", +} + + +async def handle_index(request): + return web.Response( + text=index_page, content_type="text/html", headers=default_headers + ) + + +async def handle_file(request): + file = request.rel_url.path + if not request.path.startswith("/tmp/"): + file = file.lstrip("/") + + content_type = "text/plain" + if file.endswith(".js"): + content_type = "text/javascript" + elif file.endswith(".wasm"): + content_type = "application/wasm" + elif file.endswith(".map"): + content_type = "application/json" + + contents = b"" + with open(file, "rb") as f: + contents = f.read() + + return web.Response( + body=contents, content_type=content_type, headers=default_headers + ) + + +def wrap_store_resp(fname, resp): + data = base64.b64decode(resp["data"].encode("utf-8")) + with open(fname, "wb") as f: + f.write(data) + return "" + + +async def handle_ws(request): + ws = web.WebSocketResponse(max_msg_size=2**30) + await ws.prepare(request) + + toWS = request.app["toWS"] + toStdIO = request.app["toStdIO"] + + # Notify that we have a browser connected + toStdIO.put_nowait("connected") + + while True: + cmd, args = await toWS.get() + + if cmd == "close": + break + + orig_args = args + if cmd == "store": + args = args[1:] + + await ws.send_json({"cmd": cmd, "args": args}) + msg = await ws.receive() + + if msg.type == aiohttp.WSMsgType.ERROR: + eprint("ws connection closed with exception %s" % ws.exception()) + break + + resp = json.loads(msg.data) + eprint("Got response:", resp) + + text = "" + if cmd == "store" and resp["status"] == "ok": + text = wrap_store_resp(orig_args[0], resp) + else: + text = resp["text"] + + toStdIO.put_nowait((resp["status"], text)) + + if not ws.closed: + await ws.close() + + eprint("WS connection closed, stopping server") + app["stop"].set() + + +def start_browser(): + options = webdriver.ChromeOptions() + + if headless: + options.add_argument("--headless=new") + + if remote_driver_url is not None: + driver = webdriver.Remote( + command_executor=remote_driver_url, options=options + ) + else: + # Need these extra options when running properly on Linux, but + # specifying them on Windows is not allowed. For now, just assume a + # remote driver will run on Windows and a local one on Linux, this check + # might need to be adjusted if we ever use remote drivers where the + # remote is also Linux. 
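The Cross-Origin-Opener-Policy and Cross-Origin-Embedder-Policy headers attached to every response above make the page a cross-origin isolated context (window.crossOriginIsolated becomes true), which is what buys the finer timer resolution mentioned in the comment. The same two headers work with any static server; a minimal standard-library sketch, not part of the patch:

    from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer

    class IsolatedHandler(SimpleHTTPRequestHandler):
        # Add the isolation headers to every response served from the cwd.
        def end_headers(self):
            self.send_header("Cross-Origin-Opener-Policy", "same-origin")
            self.send_header("Cross-Origin-Embedder-Policy", "require-corp")
            super().end_headers()

    if __name__ == "__main__":
        ThreadingHTTPServer(("0.0.0.0", 8100), IsolatedHandler).serve_forever()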
+ options.add_argument("--enable-unsafe-webgpu") + options.add_argument("--enable-features=Vulkan") + + if headless: + # https://developer.chrome.com/blog/supercharge-web-ai-testing#enable-webgpu + options.add_argument("--no-sandbox") + options.add_argument("--use-angle=vulkan") + options.add_argument("--disable-vulkan-surface") + + driver = webdriver.Chrome(options=options) + + loop = asyncio.get_running_loop() + get_task = loop.run_in_executor(None, driver.get, "http://localhost:8100") + + return driver, get_task + + +def stop_browser(driver): + driver.quit() + + +async def start_server(app, toWS, toStdIO): + app["toWS"] = toWS + app["toStdIO"] = toStdIO + + app["stop"] = asyncio.Event() + + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "0.0.0.0", 8100) + + await site.start() + + if not no_browser: + driver, get_task = start_browser() + + await app["stop"].wait() + if not no_browser: + await get_task + await runner.cleanup() + + if not no_browser: + stop_browser(driver) + + +app = web.Application() +app.add_routes( + [ + web.get("/", handle_index), + web.get("/index.html", handle_index), + web.get("/ws", handle_ws), + web.get("/{file:.*}", handle_file) # Catch-all for serving input files from /tmp/ or data/ directories + ] +) + + +async def read_stdin_line(): + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, sys.stdin.readline) + + +async def handle_stdio(toWS, toStdIO): + # Wait for an initial signal that a web browser client has connected. + await toStdIO.get() + + eprint("Browser client detected, starting Futhark server protocol") + print("%%% OK", flush=True) + + while True: + line = await read_stdin_line() + eprint("Got line:", line.rstrip()) + + if line.strip() == "": + toWS.put_nowait(("close", [])) + break + + command, *args = shlex.split(line) + + toWS.put_nowait((command, args)) + status, text = await toStdIO.get() + + if status == "ok": + print(text, flush=True) + print("%%% OK", flush=True) + else: + print("%%% FAILURE", flush=True) + print(text, flush=True) + print("%%% OK", flush=True) + + +async def main(): + toWS = asyncio.Queue() + toStdIO = asyncio.Queue() + if args.no_server_proxy: + eprint(f"Wrapping {script_name}; only providing web server") + eprint("Hosting at 0.0.0.0:8100") + await start_server(app, toWS, toStdIO) + else: + eprint(f"Wrapping {script_name}; proxying the Futhark server protocol") + await asyncio.gather( + start_server(app, toWS, toStdIO), handle_stdio(toWS, toStdIO) + ) + + +asyncio.run(main()) diff --git a/tools/webgpu-tester/.gitignore b/tools/webgpu-tester/.gitignore new file mode 100644 index 0000000000..7fc242aeb7 --- /dev/null +++ b/tools/webgpu-tester/.gitignore @@ -0,0 +1,3 @@ +*.js +!main.js +tests/ diff --git a/tools/webgpu-tester/assemble_tests.fish b/tools/webgpu-tester/assemble_tests.fish new file mode 100755 index 0000000000..e1d8adfa58 --- /dev/null +++ b/tools/webgpu-tester/assemble_tests.fish @@ -0,0 +1,20 @@ +#!/bin/fish + +set test_dir "tools/webgpu-tester/tests" +set meta_file "$test_dir/test-files.json" + +mkdir -p $test_dir + +echo -n "[" > $meta_file + +set delim "" +for p in $argv + set f (path basename "$p") + futhark dev --gpu-mem --test-webgpu-kernels "$p" > "$test_dir/$f.js" + + echo "$delim" >> $meta_file + set delim "," + echo -n " \"$f\"" >> $meta_file +end + +echo -e "\n]" >> $meta_file diff --git a/tools/webgpu-tester/index.html b/tools/webgpu-tester/index.html new file mode 100644 index 0000000000..f9d69a2ff4 --- /dev/null +++ 
b/tools/webgpu-tester/index.html @@ -0,0 +1,37 @@
[index.html markup lost in extraction; recoverable text: title "WebGPU Compute Testing" and heading "WebGPU WGSL Testing"]
    + + diff --git a/tools/webgpu-tester/main.js b/tools/webgpu-tester/main.js new file mode 100644 index 0000000000..8ef6db2581 --- /dev/null +++ b/tools/webgpu-tester/main.js @@ -0,0 +1,375 @@ +function typeSize(type) { + if (type == 'bool') { return 1; } + if (type == 'i8') { return 1; } + if (type == 'u8') { return 1; } + if (type == 'i16') { return 2; } + if (type == 'u16') { return 2; } + if (type == 'i32') { return 4; } + if (type == 'u32') { return 4; } + if (type == 'i64') { return 8; } + if (type == 'u64') { return 8; } + if (type == 'f32') { return 4; } + throw "unsupported type"; +} + +function alignArraySize(len) { + const r = len % 4; + return r ? len + (4 - r) : len; +} + +function toTypedArray(array, type) { + if (type == 'bool') { + const conv = array.map((x) => x ? 1 : 0); + return new Int8Array(conv); + } + if (type == 'i8') { return new Int8Array(array); } + if (type == 'u8') { return new Uint8Array(array); } + if (type == 'i16') { return new Int16Array(array); } + if (type == 'u16') { return new Uint16Array(array); } + if (type == 'i32') { return new Int32Array(array); } + if (type == 'u32') { return new Uint32Array(array); } + if (type == 'i64') { + const dest = new Int32Array(array.length * 2); + for (let i = 0; i < array.length; i++) { + dest[i*2] = Number(BigInt.asIntN(32, + array[i] & 0xffffffffn)); + dest[i*2+1] = Number(BigInt.asIntN(32, + (array[i] >> 32n) & 0xffffffffn)); + } + return dest; + } + if (type == 'u64') { + const dest = new Uint32Array(array.length * 2); + for (let i = 0; i < array.length; i++) { + dest[i*2] = Number(BigInt.asUintN(32, + array[i] & 0xffffffffn)); + dest[i*2+1] = Number(BigInt.asUintN(32, + (array[i] >> 32n) & 0xffffffffn)); + } + return dest; + } + if (type == 'f32') { return new Float32Array(array); } + throw "unsupported type"; +} + +function arrayBufferToTypedArray(buffer, type) { + if (type == 'bool') { return new Int8Array(buffer); } + if (type == 'i8') { return new Int8Array(buffer); } + if (type == 'u8') { return new Uint8Array(buffer); } + if (type == 'i16') { return new Int16Array(buffer); } + if (type == 'u16') { return new Uint16Array(buffer); } + if (type == 'i32') { return new Int32Array(buffer); } + if (type == 'u32') { return new Uint32Array(buffer); } + if (type == 'i64') { return new Int32Array(buffer); } + if (type == 'u64') { return new Uint32Array(buffer); } + if (type == 'f32') { return new Float32Array(buffer); } + throw "unsupported type"; +} + +function compareExpected(val, expected, type) { + if (type == 'f32') { + const tolerance = 0.002; + if (isNaN(val) && isNaN(expected)) { return true; } + if (!isFinite(val) && !isFinite(expected)) { + return Math.sign(val) == Math.sign(expected); + } + return Math.abs(val - expected) <= Math.abs(tolerance * expected); + } + return val == expected; +} + +async function runTest(device, shaderModule, testInfo) { + // Find kernel corresponding to entry. + let kernelInfo = undefined; + for (const k of window.kernels) { + if (k.name.includes(testInfo.entry + "zi")) { + kernelInfo = k; + break; + } + } + if (kernelInfo === undefined) { + console.error("Could not find kernel info for", testInfo); + return {0: "❗ failed to find kernel info"}; + } + + const templateRun = testInfo.runs[0]; + + // Create input buffers. + let inputBuffers = []; + for (let i = 0; i < templateRun.input.length; i++) { + const ty = templateRun.inputTypes[i]; + const inputElemSize = typeSize(ty); + + // Find maximum required size for buffer. 
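compareExpected above accepts f32 results within a relative tolerance of 0.002, treats NaN as equal to NaN, and compares non-finite values by sign only; everything else must match exactly. Restated as a small Python predicate (illustration only; close_enough is an invented name):

    import math

    def close_enough(got, expected, typ, tolerance=0.002):
        if typ != "f32":
            return got == expected
        if math.isnan(got) and math.isnan(expected):
            return True
        if not math.isfinite(got) and not math.isfinite(expected):
            # Overflows only have to agree in direction.
            return math.copysign(1.0, got) == math.copysign(1.0, expected)
        return abs(got - expected) <= abs(tolerance * expected)

    assert close_enough(1.0005, 1.0, "f32")
    assert not close_enough(1.01, 1.0, "f32")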
+ let maxLength = 0; + for (const run of testInfo.runs) { + if (run.input[i].length > maxLength) { + maxLength = run.input[i].length; + } + } + + const buffer = device.createBuffer({ + size: alignArraySize(maxLength * inputElemSize), + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST, + }); + inputBuffers.push(buffer); + } + + + let maxLength = 0; + for (const run of testInfo.runs) { + if (run.expected[0].length > maxLength) { + maxLength = run.expected[0].length; + } + } + + const outputType = templateRun.expectedTypes[0]; + const outputElemSize = typeSize(outputType); + + let outputBuffer = device.createBuffer({ + size: alignArraySize(maxLength * outputElemSize), + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC, + }); + let stagingBuffer = device.createBuffer({ + size: alignArraySize(maxLength * outputElemSize), + usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, + }); + + let scalarBuffer = device.createBuffer({ + size: 8, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, + }); + + let bglEntries = [ + { binding: kernelInfo.scalarsBindSlot, visibility: GPUShaderStage.COMPUTE, + buffer: { type: "uniform" } } + ]; + let bgEntries = [ + { binding: kernelInfo.scalarsBindSlot, resource: { buffer: scalarBuffer } } + ]; + for (let i = 0; i < inputBuffers.length; i++) { + bglEntries.push({ + binding: kernelInfo.bindSlots[i], + visibility: GPUShaderStage.COMPUTE, + buffer: { type: "storage" } + }); + bgEntries.push({ + binding: kernelInfo.bindSlots[i], + resource: { buffer: inputBuffers[i] } + }); + } + bglEntries.push({ + binding: kernelInfo.bindSlots[inputBuffers.length], + visibility: GPUShaderStage.COMPUTE, + buffer: { type: "storage" } } + ); + bgEntries.push({ + binding: kernelInfo.bindSlots[inputBuffers.length], + resource: { buffer: outputBuffer } + }); + const bgl = device.createBindGroupLayout({ + entries: bglEntries + }); + const bg = device.createBindGroup({ layout: bgl, entries: bgEntries }); + + const block_size = 256; + let overrideConsts = {}; + for (const override of kernelInfo.overrides) { + if (override.includes('lockstep_width')) { + overrideConsts[override] = 1; + } + if (override.includes('block_size') + || override.includes('block_sizze')) { + overrideConsts[override] = block_size; + } + } + + const pipeline = device.createComputePipeline({ + layout: device.createPipelineLayout({ bindGroupLayouts: [bgl] }), + compute: { + module: shaderModule, entryPoint: kernelInfo.name, + constants: overrideConsts, + } + }); + + let results = {}; + + for (let ri = 0; ri < testInfo.runs.length; ri++) { + const run = testInfo.runs[ri]; + const length = run.input[0].length; + const alignedOutLength = alignArraySize(length * outputElemSize); + + let hScalars = new Int32Array(2); + hScalars[0] = length; + hScalars[1] = 0; + device.queue.writeBuffer(scalarBuffer, 0, hScalars, 0); + + for (let i = 0; i < run.input.length; i++) { + const input = run.input[i]; + const inputType = run.inputTypes[i]; + + let hInput = toTypedArray(input, inputType); + let alignedLen = alignArraySize(hInput.byteLength); + if (alignedLen != hInput.byteLength) { + hInput = hInput.buffer.transfer(alignedLen); + } + device.queue.writeBuffer(inputBuffers[i], 0, hInput, 0); + } + + const commandEncoder = device.createCommandEncoder(); + const passEncoder = commandEncoder.beginComputePass(); + passEncoder.setPipeline(pipeline); + + passEncoder.setBindGroup(0, bg); + + passEncoder.dispatchWorkgroups(Math.ceil(length / block_size)); + passEncoder.end(); + commandEncoder.copyBufferToBuffer( + 
outputBuffer, 0, stagingBuffer, 0, alignedOutLength); + device.queue.submit([commandEncoder.finish()]); + + await stagingBuffer.mapAsync(GPUMapMode.READ, 0, alignedOutLength); + const stagingMapped = stagingBuffer.getMappedRange(0, alignedOutLength); + const data = arrayBufferToTypedArray(stagingMapped.slice(), outputType); + stagingBuffer.unmap(); + + const expected = toTypedArray(run.expected[0], outputType); + + let errors = []; + for (let i = 0; i < expected.length; i++) { + if (!compareExpected(data[i], expected[i], outputType)) { + errors.push({i: i, expect: expected[i], got: data[i]}); + } + } + + if (errors.length > 0) { + let msg = `Test for entry ${testInfo.entry}, run ${ri}: FAIL:\n`; + for (const err of errors) { + msg += ` [${err.i}] expected ${err.expect} got ${err.got}\n`; + } + console.error(msg); + + results[ri] = "❌"; + } + else { + console.log(`Test for entry ${testInfo.entry}, run ${ri}: PASS\n`); + results[ri] = "✅"; + } + } + + return results; +} + +async function getTestNames() { + const urlParams = new URLSearchParams(window.location.search); + + if (urlParams.get('tests-file')) { + const response = await fetch(urlParams.get('tests-file')); + const tests = await response.json(); + return tests; + } + + return [urlParams.get('test')]; +} + +async function evalTestInfo(testName) { + const testPath = `./tests/${testName}.js`; + console.log(`Test info from: ${testPath}`); + + const response = await fetch(testPath); + const testInfo = await response.text(); + // Executes the JS in testInfo in global scope instead of in the local + // scope. + (1, eval)(testInfo); +} + +async function init() { + if (!navigator.gpu) { + throw Error("WebGPU not supported."); + } + + const adapter = await navigator.gpu.requestAdapter(); + if (!adapter) { + throw Error("Couldn't request WebGPU adapter."); + } + + const info = await adapter.requestAdapterInfo(); + console.log("Adapter info: ", info); + + const device = await adapter.requestDevice(); + console.log("Acquired device: ", device); + + const testNames = await getTestNames(); + + var results = {}; + + for (const testName of testNames) { + try { + await evalTestInfo(testName); + } catch (e) { + results[testName] = { compiled: false, res: { + "[❌ could not load testInfo:]": {0: e} } }; + continue; + } + + const shaderModule = device.createShaderModule({ + code: shader, + }); + + const shaderInfo = await shaderModule.getCompilationInfo(); + console.log(`${testName}: Shader compilation info:`, shaderInfo); + + if (shaderInfo.messages.length > 0) { + results[testName] = { compiled: false, res: {} }; + } + else { + try { + var testResults = {}; + for (const test of window.tests) { + const r = await runTest(device, shaderModule, test); + testResults[test.entry] = r; + } + results[testName] = { compiled: true, res: testResults }; + } catch (e) { + results[testName] = { compiled: true, res: { + "[❌ error]": {0: e} } }; + continue; + } + } + } + + renderTestResults(results); +} + +function renderTestResults(results) { + const resultsContainer = document.getElementById("results"); + for (const [testName, testResults] of Object.entries(results)) { + const container = document.createElement("li"); + + const h = document.createElement("h3"); + h.innerHTML = testName; + container.appendChild(h); + + const c = document.createElement("span"); + c.classList = "compiled"; + c.innerHTML = "Compiled: " + (testResults.compiled ? 
"✅" : "❌"); + container.appendChild(c); + + for (const [entry, entryRes] of Object.entries(testResults.res)) { + const eh = document.createElement("h4"); + eh.innerHTML = entry; + container.appendChild(eh); + + for (const [run, runRes] of Object.entries(entryRes)) { + const r = document.createElement("span"); + r.classList = "run"; + r.innerHTML = run + ": " + runRes; + container.appendChild(r); + } + } + + resultsContainer.appendChild(container); + } +} + +init();