From a7da0f6c56519bdebad373dc563fc1bd3bf91fa3 Mon Sep 17 00:00:00 2001 From: David Duvenaud Date: Mon, 14 Jun 2021 11:34:22 -0400 Subject: [PATCH 1/2] Added initial version of octtree raytracer. --- examples/lodtracer.dx | 532 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 532 insertions(+) create mode 100644 examples/lodtracer.dx diff --git a/examples/lodtracer.dx b/examples/lodtracer.dx new file mode 100644 index 000000000..298704d62 --- /dev/null +++ b/examples/lodtracer.dx @@ -0,0 +1,532 @@ + +import plot + +'## General Utilies + +def pngsToSavedGif (delay:Int) (pngs:t=>Png) (outFileName:String) : Gif = + unsafeIO \(). + withTempFiles \pngFiles. + for i. + writeFile pngFiles.i pngs.i + shellOut $ + "convert" <> " -delay " <> show delay <> " " <> + concat (for i. "png:" <> pngFiles.i <> " ") <> + "gif:" <> outFileName <> ".gif" + + +def hue2rgb (p:Float) (q:Float) (t:Float) : Float = + t = t - floor t + if t < (1.0/6.0) + then p + (q - p) * 6.0 * t + else if t < (1.0/2.0) + then q + else if t < (2.0/3.0) + then p + (q - p) * (2.0/3.0 - t) * 6.0 + else p + +def hslToRgb (h:Float) (s:Float) (l:Float) : (Fin 3)=>Float = + if s == 0.0 + then [l, l, l] -- achromatic + else + q = select (l < 0.5) (l * (1.0 + s)) (l + s - l * s) + p = 2.0 * l - q + r = hue2rgb p q (h + 1.0/3.0) + g = hue2rgb p q h + b = hue2rgb p q (h - 1.0/3.0) + [r, g, b] + + +def W8ToB' (x : Word8) : Bool = x > (IToW8 0) -- Bug in prelude + +def firstbit (x:Word8) : Bool = W8ToB' $ x .&. (IToW8 1) +def secondbit (x:Word8) : Bool = W8ToB' $ x .&. (IToW8 2) +def thirdbit (x:Word8) : Bool = W8ToB' $ x .&. (IToW8 4) + +def last : n = ((size n) - 1)@n +def ixFraction (i:n) : Float = (IToF (ordinal i)) / (IToF ((size n) - 1)) + +-- :p for i:(Fin 10). ixFraction i +-- :p for i:(Fin 10). hslToRgb (ixFraction i) 1.0 0.5 +-- :p W8ToB (IToW8 4) + +def intpow2 (power:Int) : Int = %shl 1 power + + +'## Types + +def Vec (n:Int) : Type = Fin n => Float + +Position = Vec 3 +Direction = Vec 3 -- Should be normalized. TODO: use a newtype wrapper +Distance = Float + +Ray = (Position & Direction) + +data Axis = + X + Y + Z + +--Octants = Directions=>Bool +Octant = Fin 8 -- (Bool & Bool & Bool) -- positive in x, y, or z + +LoD = Fin 5 + +--def Voxel (lod:LoD) : Type = +-- (Fin lod)=>Octants + +--def Voxel (numDivisions:Int) : Type = +-- (Fin numDivisions & Fin numDivisions & Fin numDivisions) + +def Voxel : Type = + (Int & Int & Int) -- Sad! + +def AABox : Type = (Position & Position) -- opposite corners + + +'## Bounding box utilities + +def voxelWidth (lod: LoD) : Float = + -- voxels divide up cube in (-1, 1) + numDivisions = intpow2 $ ordinal lod + 2.0 / (IToF numDivisions) + +def voxelToCentrePosition (lod: LoD) (voxel: Voxel) : Position = + divSize = voxelWidth lod + (xn, yn, zn) = voxel + [divSize * (0.5 + IToF xn) - 1.0, + divSize * (0.5 + IToF yn) - 1.0, + divSize * (0.5 + IToF zn) - 1.0] + +def voxeltoBB (voxel: Voxel) (lod:LoD) : AABox = + divSize = voxelWidth lod + (xn, yn, zn) = voxel + lower = [divSize * (IToF xn) - 1.0, + divSize * (IToF yn) - 1.0, + divSize * (IToF zn) - 1.0] + upper = [divSize * (IToF (xn + 1)) - 1.0, + divSize * (IToF (yn + 1)) - 1.0, + divSize * (IToF (zn + 1)) - 1.0] + (lower, upper) + +def rayAABBIntersection (ray:Ray) (box:AABox) : Bool = + ([pos_x, pos_y, pos_z], [dir_x, dir_y, dir_z ]) = ray + ([low_x, low_y, low_z], [high_x, high_y, high_z]) = box + (tx1, tx2) = ((low_x - pos_x) / dir_x, (high_x - pos_x) / dir_x) + (ty1, ty2) = ((low_y - pos_y) / dir_y, (high_y - pos_y) / dir_y) + (tz1, tz2) = ((low_z - pos_z) / dir_z, (high_z - pos_z) / dir_z) + txn = min tx1 tx2 + txf = max tx1 tx2 + tyn = min ty1 ty2 + tyf = max ty1 ty2 + tzn = min tz1 tz2 + tzf = max tz1 tz2 + tnear = max txn $ max tyn tzn + tfar = min txf $ min tyf tzf + tfar > tnear + +def intersectsVoxel (lod: LoD) (ray:Ray) (voxel: Voxel) : Bool = + box = voxeltoBB voxel lod + rayAABBIntersection ray box + + +--def decide (lod: LOD) (ray:Ray) (voxel: Voxel) : (Fin 8)=>Bool = + -- Returned voxels are of the next level of detail. +-- if intersectsVoxel ray voxel then +-- (True, True, True, True, True, True, True, True) +-- else + + +'## Octtree intersection + +octantBackToFrontTable : Octant=>Octant=>Octant = + intTable = [[ 0, 1, 2, 4, 3, 5, 6, 7 ], + [ 1, 0, 3, 5, 2, 4, 7, 6 ], + [ 2, 0, 3, 6, 1, 4, 7, 5 ], + [ 3, 1, 2, 7, 0, 5, 6, 4 ], + [ 4, 0, 5, 6, 1, 2, 7, 3 ], + [ 5, 1, 4, 7, 0, 3, 6, 2 ], + [ 6, 2, 4, 7, 0, 3, 5, 1 ], + [ 7, 3, 5, 6, 1, 2, 4, 0 ]] + for i j. (intTable.i.j@Octant) + +def childOctantToSubVoxel (lod: LoD) (voxel: Voxel) (oct:Octant) : Voxel = + oo = IToW8 $ ordinal oct + (x, y, z) = voxel + (2 * x + (BToI $ firstbit oo), + 2 * y + (BToI $ secondbit oo), + 2 * z + (BToI $ thirdbit oo)) + +def subVoxelPosToChildOctant (lod: LoD) (voxel: Voxel) (pos:Position) : Octant = + [mid_x, mid_y, mid_z] = voxelToCentrePosition lod voxel + [pos_x, pos_y, pos_z] = pos + bit1 = BToW8 $ pos_x > mid_x + bit2 = BToW8 $ pos_y > mid_y + bit3 = BToW8 $ pos_z > mid_z + bits = W8ToI $ bit1 .|. (bit2 << 1) .|. (bit3 << 2) + bits@Octant + +def orderedChildren (lod: LoD) (ray:Ray) (voxel: Voxel) : Octant=>Octant = + -- todo: return only those children that intersect + (pos, dir) = ray + oct = subVoxelPosToChildOctant lod voxel pos + octantBackToFrontTable oct + +def subdivide (ray:Ray) (lod: LoD) ((AsList _ voxels) : List Voxel) : List Voxel = + -- Returned voxels are of the next level of detail. + -- Todo: Encode this in the types. + yieldAccum (ListMonoid Voxel) \list. + for t. + if intersectsVoxel lod ray voxels.t then + childOcts = orderedChildren lod ray voxels.t + list += AsList 8 for i:(Fin 8). childOctantToSubVoxel lod voxels.t childOcts.i + +def rayTraceOctTree (ray: Ray) : List Voxel = + -- Returns a depth-ordered list of voxels intersecting a ray. + top_voxel = (0, 0, 0) + update = subdivide ray + init = AsList 1 [top_voxel] + fold init update + + +rayvoxels = rayTraceOctTree ([0.1, 0.1, 0.1], [1.0, 1.0, 1.0]) + +-- :p rayvoxels + + + +---- Tests + +--p = voxelToCentrePosition (2@LoD) (0, 2, 0) +-- :p subVoxelPosToChildOctant (1@LoD) (0, 2, 0) p + +:p rayAABBIntersection ([0., 0., 0.], [0.1, 0.1, 0.1]) ([-0.1, -0.1, -0.1], [0.1, 0.1, 0.1]) + + + +'## Raytracer for debugging + + + + +' ### Generic Helper Functions +Some of these should probably go in prelude. + +-- def Vec (n:Int) : Type = Fin n => Float +def Mat (n:Int) (m:Int) : Type = Fin n => Fin m => Float + +def relu (x:Float) : Float = max x 0.0 +def length (x: d=>Float) : Float = sqrt $ sum for i. sq x.i +-- TODO: make a newtype for normal vectors +def normalize (x: d=>Float) : d=>Float = x / (length x) +def directionAndLength (x: d=>Float) : (d=>Float & Float) = + l = length x + (x / (length x), l) + +def randuniform (lower:Float) (upper:Float) (k:Key) : Float = + lower + (rand k) * (upper - lower) + +def sampleAveraged [VSpace a] (sample:Key -> a) (n:Int) (k:Key) : a = + yieldState zero \total. + for i:(Fin n). + total := get total + sample (ixkey k i) / IToF n + +def positiveProjection (x:n=>Float) (y:n=>Float) : Bool = dot x y > 0.0 + +' ### 3D Helper Functions + +def cross (a:Vec 3) (b:Vec 3) : Vec 3 = + [a1, a2, a3] = a + [b1, b2, b3] = b + [a2 * b3 - a3 * b2, a3 * b1 - a1 * b3, a1 * b2 - a2 * b1] + +-- TODO: Use `data Color = Red | Green | Blue` and ADTs for index sets +data Image = + MkImage height:Int width:Int (Fin height => Fin width => Color) + +xHat : Vec 3 = [1., 0., 0.] +yHat : Vec 3 = [0., 1., 0.] +zHat : Vec 3 = [0., 0., 1.] + +Angle = Float -- angle in radians + +def rotateX (p:Vec 3) (angle:Angle) : Vec 3 = + c = cos angle + s = sin angle + [px, py, pz] = p + [px, c*py - s*pz, s*py + c*pz] + +def rotateY (p:Vec 3) (angle:Angle) : Vec 3 = + c = cos angle + s = sin angle + [px, py, pz] = p + [c*px + s*pz, py, - s*px+ c*pz] + +def rotateZ (p:Vec 3) (angle:Angle) : Vec 3 = + c = cos angle + s = sin angle + [px, py, pz] = p + [c*px - s*py, s*px+c*py, pz] + +def sampleCosineWeightedHemisphere (normal: Vec 3) (k:Key) : Vec 3 = + [k1, k2] = splitKey k + u1 = rand k1 + u2 = rand k2 + uu = normalize $ cross normal [0.0, 1.1, 1.1] + vv = cross uu normal + ra = sqrt u2 + rx = ra * cos (2.0 * pi * u1) + ry = ra * sin (2.0 * pi * u1) + rz = sqrt (1.0 - u2) + rr = (rx .* uu) + (ry .* vv) + (rz .* normal) + normalize rr + +' ### Raytracer + +BlockHalfWidths = Vec 3 +Radius = Float +Radiance = Color + +data ObjectGeom = + Wall Direction Distance + Block Position BlockHalfWidths Angle + Sphere Position Radius + +data Surface = + Matte Color + Mirror + +OrientedSurface = (Direction & Surface) + +data Object = + PassiveObject ObjectGeom Surface + -- position, half-width, intensity (assumed to point down) + Light Position Float Radiance + +Filter = Color + +-- TODO: use a record +-- num samples, num bounces, share seed? +Params = { numSamples : Int + & maxBounces : Int + & shareSeed : Bool } + +-- TODO: use a list instead, once they work +data Scene n:Type = MkScene (n=>Object) + +def sampleReflection ((nor, surf):OrientedSurface) ((pos, dir):Ray) (k:Key) : Ray = + newDir = case surf of + Matte _ -> sampleCosineWeightedHemisphere nor k + -- TODO: surely there's some change-of-solid-angle correction we need to + -- consider when reflecting off a curved surface. + Mirror -> dir - (2.0 * dot dir nor) .* nor + (pos, newDir) + +def probReflection ((nor, surf):OrientedSurface) (_:Ray) ((_, outRayDir):Ray) : Float = + case surf of + Matte _ -> relu $ dot nor outRayDir + Mirror -> 0.0 -- TODO: this should be a delta function of some sort + +def applyFilter (filter:Filter) (radiance:Radiance) : Radiance = + for i. filter.i * radiance.i + +def surfaceFilter (filter:Filter) (surf:Surface) : Filter = + case surf of + Matte color -> for i. filter.i * color.i + Mirror -> filter + +def sdObject (pos:Position) (obj:Object) : Distance = + case obj of + PassiveObject geom _ -> case geom of + Wall nor d -> d + dot nor pos + Block blockPos halfWidths angle -> + pos' = rotateY (pos - blockPos) angle + length $ for i. max ((abs pos'.i) - halfWidths.i) 0.0 + Sphere spherePos r -> + pos' = pos - spherePos + max (length pos' - r) 0.0 + Light squarePos hw _ -> + pos' = pos - squarePos + halfWidths = [hw, 0.01, hw] + length $ for i. max ((abs pos'.i) - halfWidths.i) 0.0 + +def sdScene (scene:Scene n) (pos:Position) : (Object & Distance) = + (MkScene objs) = scene + (i, d) = minimumBy snd $ for i. (i, sdObject pos objs.i) + (objs.i, d) + +def calcNormal (obj:Object) (pos:Position) : Direction = + normalize (grad (flip sdObject obj) pos) + +data RayMarchResult = + -- incident ray, surface normal, surface properties + HitObj Ray OrientedSurface + HitLight Radiance + -- Could refine with failure reason (beyond horizon, failed to converge etc) + HitNothing + +def raymarch (scene:Scene n) (ray:Ray) : RayMarchResult = + maxIters = 100 + tol = 0.01 + startLength = 10.0 * tol -- trying to escape the current surface + (rayOrigin, rayDir) = ray + withState (10.0 * tol) \rayLength. + boundedIter maxIters HitNothing \_. + rayPos = rayOrigin + get rayLength .* rayDir + (obj, d) = sdScene scene $ rayPos + -- 0.9 ensures we come close to the surface but don't touch it + rayLength := get rayLength + 0.9 * d + case d < tol of + False -> Continue + True -> + surfNorm = calcNormal obj rayPos + case positiveProjection rayDir surfNorm of + True -> + -- Oops, we didn't escape the surface we're leaving.. + -- (Is there a more standard way to do this?) + Continue + False -> + -- We made it! + Done $ case obj of + PassiveObject _ surf -> HitObj (rayPos, rayDir) (surfNorm, surf) + Light _ _ radiance -> HitLight radiance + +def rayDirectRadiance (scene:Scene n) (ray:Ray) : Radiance = + case raymarch scene ray of + HitLight intensity -> intensity + HitNothing -> zero + HitObj _ _ -> zero + +def sampleSquare (hw:Float) (k:Key) : Position = + [kx, kz] = splitKey k + x = randuniform (- hw) hw kx + z = randuniform (- hw) hw kz + [x, 0.0, z] + +def sampleLightRadiance + (scene:Scene n) (osurf:OrientedSurface) (inRay:Ray) (k:Key) : Radiance = + (surfNor, surf) = osurf + (rayPos, _) = inRay + (MkScene objs) = scene + yieldAccum (AddMonoid Float) \radiance. + for i. case objs.i of + PassiveObject _ _ -> () + Light lightPos hw _ -> + (dirToLight, distToLight) = directionAndLength $ + lightPos + sampleSquare hw k - rayPos + if positiveProjection dirToLight surfNor then + -- light on this far side of current surface + fracSolidAngle = (relu $ dot dirToLight yHat) * sq hw / (pi * sq distToLight) + outRay = (rayPos, dirToLight) + coeff = fracSolidAngle * probReflection osurf inRay outRay + radiance += coeff .* rayDirectRadiance scene outRay + +def trace (params:Params) (scene:Scene n) (initRay:Ray) (k:Key) : Color = + noFilter = [1.0, 1.0, 1.0] + yieldAccum (AddMonoid Float) \radiance. + runState noFilter \filter. + runState initRay \ray. + boundedIter (getAt #maxBounces params) () \i. + case raymarch scene $ get ray of + HitNothing -> Done () + HitLight intensity -> + if i == 0 then radiance += intensity -- TODO: scale etc + Done () + HitObj incidentRay osurf -> + [k1, k2] = splitKey $ hash k i + lightRadiance = sampleLightRadiance scene osurf incidentRay k1 + ray := sampleReflection osurf incidentRay k2 + filter := surfaceFilter (get filter) (snd osurf) + radiance += applyFilter (get filter) lightRadiance + Continue + +-- Assumes we're looking towards -z. +Camera = + { numPix : Int + & pos : Position -- pinhole position + & halfWidth : Float -- sensor half-width + & sensorDist : Float } -- pinhole-sensor distance + +-- TODO: might be better with an anonymous dependent pair for the result +def cameraRays (n:Int) (camera:Camera) : Fin n => Fin n => (Key -> Ray) = + -- images indexed from top-left + halfWidth = getAt #halfWidth camera + pixHalfWidth = halfWidth / IToF n + ys = reverse $ linspace (Fin n) (neg halfWidth) halfWidth + xs = linspace (Fin n) (neg halfWidth) halfWidth + for i j. \key. + [kx, ky] = splitKey key + x = xs.j + randuniform (-pixHalfWidth) pixHalfWidth kx + y = ys.i + randuniform (-pixHalfWidth) pixHalfWidth ky + (getAt #pos camera, normalize [x, y, neg (getAt #sensorDist camera)]) + +def takePicture (params:Params) (scene:Scene m) (camera:Camera) : Image = + n = getAt #numPix camera + rays = cameraRays n camera + rootKey = newKey 0 + image = for i j. + pixKey = if getAt #shareSeed params + then rootKey + else ixkey (ixkey rootKey i) j + sampleRayColor : Key -> Color = \k. + [k1, k2] = splitKey k + trace params scene (rays.i.j k1) k2 + sampleAveraged sampleRayColor (getAt #numSamples params) pixKey + MkImage _ _ $ image / mean (for (i,j,k). image.i.j.k) + +' ### Define the scene and render it + +lightColor = [0.2, 0.2, 0.2] +leftWallColor = 1.5 .* [0.611, 0.0555, 0.062] +rightWallColor = 1.5 .* [0.117, 0.4125, 0.115] +whiteWallColor = [255.0, 239.0, 196.0] / 255.0 +blockColor = [200.0, 200.0, 255.0] / 255.0 + +objs1 = [ Light (1.9 .* yHat) 0.5 lightColor + , PassiveObject (Wall xHat 2.0) (Matte whiteWallColor) + , PassiveObject (Wall (neg xHat) 2.0) (Matte whiteWallColor) + , PassiveObject (Wall yHat 2.0) (Matte whiteWallColor) + , PassiveObject (Wall (neg yHat) 2.0) (Matte whiteWallColor) + , PassiveObject (Wall zHat 2.0) (Matte whiteWallColor) + -- , PassiveObject (Block [ 1.0, -1.6, 1.2] [0.6, 0.8, 0.6] 0.5) (Matte blockColor) + -- , PassiveObject (Sphere [-1.0, -1.2, 0.2] 0.8) (Matte (0.7.* whiteWallColor)) + -- , PassiveObject (Sphere [ 2.0, 2.0, -2.0] 1.5) (Mirror) + ] + +(AsList _ vl) = rayvoxels + +shrink = 0.45 +objs2 = for i. + cpos = shrink .* (voxelToCentrePosition last vl.i) + vw = shrink * (0.5 * voxelWidth last) + color = hslToRgb (ixFraction i) 1.0 0.5 + PassiveObject (Block cpos [vw, vw, vw] 0.0) (Matte color) + +combl = (AsList _ objs1) <> (AsList _ objs2) +(AsList _ cot) = combl +theScene = MkScene cot + +defaultParams = { numSamples = 1 + , maxBounces = 2 + , shareSeed = True } + +defaultCamera = { numPix = 400 + , pos = 10.0 .* zHat + , halfWidth = 0.3 + , sensorDist = 1.0 } + +-- We change to a small num pix here to reduce the compute needed for tests +params = defaultParams +camera = if dex_test_mode () + then defaultCamera |> setAt #numPix 10 + else defaultCamera + +%time +(MkImage _ _ image) = takePicture params theScene camera +:html imshow image + + +-- :html imseqshow xmovieflat +-- pngsToSavedGif 1 (map imgToPng xmovieflat) "gwg" + + From 3ae28df62709f99eb17d22c650980d6172ca0795 Mon Sep 17 00:00:00 2001 From: Cynthia Shen <58400192+cyntsh@users.noreply.github.com> Date: Fri, 10 Sep 2021 20:15:58 -0400 Subject: [PATCH 2/2] Load armadillo weights --- examples/lodtracer.dx | 929 ++++++++++++++++++++---------------------- 1 file changed, 451 insertions(+), 478 deletions(-) diff --git a/examples/lodtracer.dx b/examples/lodtracer.dx index 298704d62..fd208add7 100644 --- a/examples/lodtracer.dx +++ b/examples/lodtracer.dx @@ -1,115 +1,320 @@ +'# Neural implicit rendering +Adapted from [NGLOD](https://nv-tlabs.github.io/nglod/) by Towaki Takikawa et al., +which encodes geometry implicitly in a small neural network. +Since NN computations are expensive, we enhance the basic +sphere-tracing algorithm with an octree ADT to cut them down. import plot -'## General Utilies - -def pngsToSavedGif (delay:Int) (pngs:t=>Png) (outFileName:String) : Gif = - unsafeIO \(). - withTempFiles \pngFiles. - for i. - writeFile pngFiles.i pngs.i - shellOut $ - "convert" <> " -delay " <> show delay <> " " <> - concat (for i. "png:" <> pngFiles.i <> " ") <> - "gif:" <> outFileName <> ".gif" - - -def hue2rgb (p:Float) (q:Float) (t:Float) : Float = - t = t - floor t - if t < (1.0/6.0) - then p + (q - p) * 6.0 * t - else if t < (1.0/2.0) - then q - else if t < (2.0/3.0) - then p + (q - p) * (2.0/3.0 - t) * 6.0 - else p - -def hslToRgb (h:Float) (s:Float) (l:Float) : (Fin 3)=>Float = - if s == 0.0 - then [l, l, l] -- achromatic - else - q = select (l < 0.5) (l * (1.0 + s)) (l + s - l * s) - p = 2.0 * l - q - r = hue2rgb p q (h + 1.0/3.0) - g = hue2rgb p q h - b = hue2rgb p q (h - 1.0/3.0) - [r, g, b] +'### Types +Height = Fin 600 +Width = Fin 600 -def W8ToB' (x : Word8) : Bool = x > (IToW8 0) -- Bug in prelude - -def firstbit (x:Word8) : Bool = W8ToB' $ x .&. (IToW8 1) -def secondbit (x:Word8) : Bool = W8ToB' $ x .&. (IToW8 2) -def thirdbit (x:Word8) : Bool = W8ToB' $ x .&. (IToW8 4) - -def last : n = ((size n) - 1)@n -def ixFraction (i:n) : Float = (IToF (ordinal i)) / (IToF ((size n) - 1)) - --- :p for i:(Fin 10). ixFraction i --- :p for i:(Fin 10). hslToRgb (ixFraction i) 1.0 0.5 --- :p W8ToB (IToW8 4) - -def intpow2 (power:Int) : Int = %shl 1 power - - -'## Types - -def Vec (n:Int) : Type = Fin n => Float - -Position = Vec 3 -Direction = Vec 3 -- Should be normalized. TODO: use a newtype wrapper +def Vec (n: Int) : Type = Fin n => Float +Position = Vec 3 +Direction = Vec 3 Distance = Float - Ray = (Position & Direction) - -data Axis = - X - Y - Z - ---Octants = Directions=>Bool -Octant = Fin 8 -- (Bool & Bool & Bool) -- positive in x, y, or z - +Voxel = (Float & Float & Float) LoD = Fin 5 +Octant = Fin 8 +def AABox : Type = (Position & Position) -- opposite corners ---def Voxel (lod:LoD) : Type = --- (Fin lod)=>Octants +top_voxel = (-1., -1., -1.) +top_size = 2. ---def Voxel (numDivisions:Int) : Type = --- (Fin numDivisions & Fin numDivisions & Fin numDivisions) +fdim = 32 +FDim = Fin fdim -def Voxel : Type = - (Int & Int & Int) -- Sad! +fsize = 4 -def AABox : Type = (Position & Position) -- opposite corners +Input = Fin (fdim + 3) +HiddenLayer = Fin 128 +InputWeights = HiddenLayer => Input => Float +OutputWeights = Fin 1 => HiddenLayer => Float +Biases = HiddenLayer => Float +OutputBiases = Fin 1 => Float +'### General Utilies -'## Bounding box utilities +def W8ToB' (x : Word8) : Bool = x > (IToW8 0) -- Bug in prelude -def voxelWidth (lod: LoD) : Float = +def firstbit (x:Word8) : Bool = W8ToB' $ x .&. (IToW8 1) +def secondbit (x:Word8) : Bool = W8ToB' $ x .&. (IToW8 2) +def thirdbit (x:Word8) : Bool = W8ToB' $ x .&. (IToW8 4) + +def lastLoD : LoD = ((size LoD) - 1)@LoD + +def vecLength (v: Vec 3) : Distance = + [x, y, z] = v + sqrt $ sum for i. sq v.i + +'### Loading NN parameters + +def unsafeSlice (xs: n => a) (start:Int) (m: Type) : m => a = + for i. xs.(unsafeFromOrdinal _ (ordinal i + start)) + +def mkFeatureVol (pts: Fin b => c) (m: Type) : FDim => m => m => m => c = + vol = for d2:(FDim). + for d3:m. + for d4:m. + for d5:m. + pts.((ordinal (d2, d3, d4, d5))@_) + vol + +def mkWeights (x: m=>v) : n=>o=>v = + for i j. x.(ordinal (i,j)@_) + +def mkBiases (x: m=>v) : n=>v = + for i. x.(ordinal i@_) + +def loadFloats (f: FilePath): List Float32 = + (AsList byteCount bytesArray) = unsafeIO do readFile f + floatCount = idiv byteCount 4 + floatArray: Fin floatCount => Float32 = unsafeIO do + withTabPtr bytesArray \ptr. + tabFromPtr (Fin floatCount) $ castPtr ptr + AsList floatCount floatArray + +def getIndex (power:Int) : Int = + yieldAccum (AddMonoid Int) \sum. + for i:(Fin power). + r = intpow2 (ordinal i) + fsizei = fsize * r + flatsize = fdim * (fsizei+1) * (fsizei+1) * (fsizei+1) + sum += flatsize + +def getSize (power:Int) : Int = fsize * (intpow2 power) + 1 + +loaded_features = loadFloats "examples/features.bin" +(AsList numf features) = loaded_features + +def query_feature_vol (lod: Int) (m: Type): FDim => m => m => m => Float = + level = lod + j = getIndex level + flatsizei = fdim * (size m) * (size m) * (size m) + featuresi = unsafeSlice features j (Fin flatsizei) + feature_vol = mkFeatureVol featuresi m + feature_vol + +lod0 = getSize 0 +lod1 = getSize 1 +lod2 = getSize 2 +lod3 = getSize 3 +lod4 = getSize 4 + +f0 = query_feature_vol 0 (Fin lod0) +f1 = query_feature_vol 1 (Fin lod1) +f2 = query_feature_vol 2 (Fin lod2) +f3 = query_feature_vol 3 (Fin lod3) +f4 = query_feature_vol 4 (Fin lod4) + +loaded_params = loadFloats "examples/weights.bin" +(AsList nump params) = loaded_params + +sInputWeights = (size HiddenLayer) * (size Input) +sBiases = size HiddenLayer +sOutputWeights = size HiddenLayer +sOutputBiases = 1 +numParam = (sInputWeights) + (sBiases) + (sOutputWeights) + (sOutputBiases) + +def query_input_weights (lod: Int) : InputWeights = + j = lod * numParam + weights = unsafeSlice params j (Fin sInputWeights) + iw: InputWeights = mkWeights weights + iw + +def query_biases (lod: Int) : Biases = + offset = sInputWeights + j = lod * numParam + offset + biases = unsafeSlice params j (Fin sBiases) + b: Biases = mkBiases biases + b + +def query_output_weights (lod: Int) : OutputWeights = + offset = sInputWeights + sBiases + j = lod * numParam + offset + weights = unsafeSlice params j (Fin sOutputWeights) + iw: OutputWeights = mkWeights weights + iw + +def query_output_biases (lod: Int) : OutputBiases = + offset = sInputWeights + sBiases + sOutputWeights + j = lod * numParam + offset + biases = unsafeSlice params j (Fin sOutputBiases) + b: OutputBiases = mkBiases biases + b + +w1 = query_input_weights (ordinal lastLoD) +b1 = query_biases (ordinal lastLoD) +w2 = query_output_weights (ordinal lastLoD) +b2 = query_output_biases (ordinal lastLoD) + +'### Interpolation + +def safe_get (input: FDim => (Fin a) => (Fin a) => (Fin a) => Float) + (x: Float) (y: Float) (z: Float) + (fsize: Float) (dim: Int): Float = + case (x >= 0. && x < fsize && y >= 0. && y < fsize && z >= 0. && z < fsize) of + True -> input.(dim@_).((FToI z)@_).((FToI y)@_).((FToI x)@_) + False -> 0. + +def clip_coordinates (input: Float) (clip_limit: Float) : Float = + min (clip_limit - 1.) (max input 0.) + +-- implements torch.nn.Functional.grid_sample from https://github.com/pytorch/pytorch/blob/f064c5aa33483061a48994608d890b968ae53fb5/aten/src/THNN/generic/VolumetricGridSamplerBilinear.c +def trilinear_interp (feature_vol: FDim => (Fin a) => (Fin a) => (Fin a) => Float) + (p: Position) : FDim => Float = + fsize = IToF a + -- normalize position from [-1, 1] to [0, fsize-1] + x = ((p.(0@_) + 1.) / 2.) * (fsize - 1.) + y = ((p.(1@_) + 1.) / 2.) * (fsize - 1.) + z = ((p.(2@_) + 1.) / 2.) * (fsize - 1.) + -- corner values from (x, y, z) : north-east-south-west-top-bottom + ix_tnw = floor x + iy_tnw = floor y + iz_tnw = floor z + ix_tne = ix_tnw + 1.0 + iy_tne = iy_tnw + iz_tne = iz_tnw + ix_tsw = ix_tnw + iy_tsw = iy_tnw + 1.0 + iz_tsw = iz_tnw + ix_tse = ix_tnw + 1.0 + iy_tse = iy_tnw + 1.0 + iz_tse = iz_tnw + ix_bnw = ix_tnw + iy_bnw = iy_tnw + iz_bnw = iz_tnw + 1.0 + ix_bne = ix_tnw + 1.0 + iy_bne = iy_tnw + iz_bne = iz_tnw + 1.0 + ix_bsw = ix_tnw + iy_bsw = iy_tnw + 1.0 + iz_bsw = iz_tnw + 1.0 + ix_bse = ix_tnw + 1.0 + iy_bse = iy_tnw + 1.0 + iz_bse = iz_tnw + 1.0 + -- weighted corner values from (x, y, z) + tnw = (ix_bse - x) * (iy_bse - y) * (iz_bse - z) + tne = (x - ix_bsw) * (iy_bsw - y) * (iz_bsw - z) + tsw = (ix_bne - x) * (y - iy_bne) * (iz_bne - z) + tse = (x - ix_bnw) * (y - iy_bnw) * (iz_bnw - z) + bnw = (ix_tse - x) * (iy_tse - y) * (z - iz_tse) + bne = (x - ix_tsw) * (iy_tsw - y) * (z - iz_tsw) + bsw = (ix_tne - x) * (y - iy_tne) * (z - iz_tne) + bse = (x - ix_tnw) * (y - iy_tnw) * (z - iz_tnw) + -- pad out-of-bound grid locations with border values + x_tnw = clip_coordinates ix_tnw fsize + y_tnw = clip_coordinates iy_tnw fsize + z_tnw = clip_coordinates iz_tnw fsize + x_tne = clip_coordinates ix_tne fsize + y_tne = clip_coordinates iy_tne fsize + z_tne = clip_coordinates iz_tne fsize + x_tsw = clip_coordinates ix_tsw fsize + y_tsw = clip_coordinates iy_tsw fsize + z_tsw = clip_coordinates iz_tsw fsize + x_tse = clip_coordinates ix_tse fsize + y_tse = clip_coordinates iy_tse fsize + z_tse = clip_coordinates iz_tse fsize + x_bnw = clip_coordinates ix_bnw fsize + y_bnw = clip_coordinates iy_bnw fsize + z_bnw = clip_coordinates iz_bnw fsize + x_bne = clip_coordinates ix_bne fsize + y_bne = clip_coordinates iy_bne fsize + z_bne = clip_coordinates iz_bne fsize + x_bsw = clip_coordinates ix_bsw fsize + y_bsw = clip_coordinates iy_bsw fsize + z_bsw = clip_coordinates iz_bsw fsize + x_bse = clip_coordinates ix_bse fsize + y_bse = clip_coordinates iy_bse fsize + z_bse = clip_coordinates iz_bse fsize + for i:(FDim). + tnw_val = safe_get feature_vol x_tnw y_tnw z_tnw fsize (ordinal i) + tne_val = safe_get feature_vol x_tne y_tne z_tne fsize (ordinal i) + tsw_val = safe_get feature_vol x_tsw y_tsw z_tsw fsize (ordinal i) + tse_val = safe_get feature_vol x_tse y_tse z_tse fsize (ordinal i) + bnw_val = safe_get feature_vol x_bnw y_bnw z_bnw fsize (ordinal i) + bne_val = safe_get feature_vol x_bne y_bne z_bne fsize (ordinal i) + bsw_val = safe_get feature_vol x_bsw y_bsw z_bsw fsize (ordinal i) + bse_val = safe_get feature_vol x_bse y_bse z_bse fsize (ordinal i) + tnw_val * tnw + tne_val * tne + tsw_val * tsw + tse_val * tse + bnw_val * bnw + bne_val * bne + bsw_val * bsw + bse_val * bse + +'### Neural network + +@noinline +def relu (input : Float) : Float = + select (input > 0.0) input 0.0 + +def dilloMLP (p: Position) : Distance = + f = (trilinear_interp f0 p) + (trilinear_interp f1 p) + (trilinear_interp f2 p) + (trilinear_interp f3 p) + (trilinear_interp f4 p) + comb_input = (AsList _ p) <> (AsList _ f) + (AsList _ cinput) = comb_input + input = for i:Input. cinput.(unsafeFromOrdinal _ (ordinal i)) + l1 = w1 **. input + b1 + l2 = for i. relu l1.i + l3 = w2 **. l2 + b2 + l3.(0@_) + +def finitediff (x: Position) : Direction = + min_dist = 1.0/(64.0 * 3.0) + eps_x = [min_dist, 0.0, 0.0] + eps_y = [0.0, min_dist, 0.0] + eps_z = [0.0, 0.0, min_dist] + inputs = [(x + eps_x), (x - eps_x), (x + eps_y), (x - eps_y), (x + eps_z), (x - eps_z)] + lst = yieldAccum (ListMonoid Float) \list. + for i:(Fin 6). + list += AsList _ [(dilloMLP inputs.i) / (min_dist*2.0)] + (AsList _ g) = lst + [g.(0@_) - g.(1@_), g.(2@_) - g.(3@_), g.(4@_) - g.(5@_)] + +'### Voxel utilities + +def voxelWidth (lod: Int) : Float = -- voxels divide up cube in (-1, 1) - numDivisions = intpow2 $ ordinal lod - 2.0 / (IToF numDivisions) + numDivisions = intpow2 $ lod + top_size / (IToF numDivisions) def voxelToCentrePosition (lod: LoD) (voxel: Voxel) : Position = - divSize = voxelWidth lod + divSize = voxelWidth (ordinal lod) (xn, yn, zn) = voxel - [divSize * (0.5 + IToF xn) - 1.0, - divSize * (0.5 + IToF yn) - 1.0, - divSize * (0.5 + IToF zn) - 1.0] + [0.5*divSize + xn, + 0.5*divSize + yn, + 0.5*divSize + zn] -def voxeltoBB (voxel: Voxel) (lod:LoD) : AABox = - divSize = voxelWidth lod - (xn, yn, zn) = voxel - lower = [divSize * (IToF xn) - 1.0, - divSize * (IToF yn) - 1.0, - divSize * (IToF zn) - 1.0] - upper = [divSize * (IToF (xn + 1)) - 1.0, - divSize * (IToF (yn + 1)) - 1.0, - divSize * (IToF (zn + 1)) - 1.0] +octantBackToFrontTable : Octant => Octant => Octant = + intTable = [[ 0, 1, 2, 4, 3, 5, 6, 7 ], + [ 1, 0, 3, 5, 2, 4, 7, 6 ], + [ 2, 0, 3, 6, 1, 4, 7, 5 ], + [ 3, 1, 2, 7, 0, 5, 6, 4 ], + [ 4, 0, 5, 6, 1, 2, 7, 3 ], + [ 5, 1, 4, 7, 0, 3, 6, 2 ], + [ 6, 2, 4, 7, 0, 3, 5, 1 ], + [ 7, 3, 5, 6, 1, 2, 4, 0 ]] + for i j. (intTable.i.j@Octant) + +def subVoxelPosToChildOctant (lod: LoD) (voxel: Voxel) (pos:Position) : Octant = + [mid_x, mid_y, mid_z] = voxelToCentrePosition lod voxel + [pos_x, pos_y, pos_z] = pos + bit1 = BToW8 $ pos_x > mid_x + bit2 = BToW8 $ pos_y > mid_y + bit3 = BToW8 $ pos_z > mid_z + bits = W8ToI $ bit1 .|. (bit2 .<<. 1) .|. (bit3 .<<. 2) + bits@Octant + +'### Bounding box utilities + +def voxeltoBB (lod:Int) (voxel: Voxel) : AABox = + vwidth = voxelWidth lod + (x, y, z) = voxel + lower = [x, y, z] + upper = [x + vwidth, + y + vwidth, + z + vwidth] (lower, upper) -def rayAABBIntersection (ray:Ray) (box:AABox) : Bool = +def rayAABBIntersects (ray:Ray) (box:AABox) : Bool = ([pos_x, pos_y, pos_z], [dir_x, dir_y, dir_z ]) = ray ([low_x, low_y, low_z], [high_x, high_y, high_z]) = box (tx1, tx2) = ((low_x - pos_x) / dir_x, (high_x - pos_x) / dir_x) @@ -123,410 +328,178 @@ def rayAABBIntersection (ray:Ray) (box:AABox) : Bool = tzf = max tz1 tz2 tnear = max txn $ max tyn tzn tfar = min txf $ min tyf tzf - tfar > tnear + tfar > tnear && tfar > 0.0 -def intersectsVoxel (lod: LoD) (ray:Ray) (voxel: Voxel) : Bool = - box = voxeltoBB voxel lod +def rayAABBIntersection (ray: Ray) (box: AABox) : (Position & Position) = + ([pos_x, pos_y, pos_z], [dir_x, dir_y, dir_z ]) = ray + ([low_x, low_y, low_z], [high_x, high_y, high_z]) = box + (tx1, tx2) = case dir_x == 0.0 of + True -> (low_x, high_x) + False -> ((low_x - pos_x) / dir_x, (high_x - pos_x) / dir_x) + (ty1, ty2) = case dir_y == 0.0 of + True -> (low_y, high_y) + False -> ((low_y - pos_y) / dir_y, (high_y - pos_y) / dir_y) + (tz1, tz2) = case dir_z == 0.0 of + True -> (low_z, high_z) + False -> ((low_z - pos_z) / dir_z, (high_z - pos_z) / dir_z) + ([pos_x + tx1 * dir_x, pos_y + ty1 * dir_y, pos_z + tz1 * dir_z], + [pos_x + tx2 * dir_x, pos_y + ty2 * dir_y, pos_z + tz2 * dir_z]) + +def intersectsVoxel (ray: Ray) (lod: Int) (voxel: Voxel) : Bool = + box = voxeltoBB lod voxel + rayAABBIntersects ray box + +def intersection (ray: Ray) (lod: Int) (voxel: Voxel) : (Position & Position) = + box = voxeltoBB lod voxel rayAABBIntersection ray box +'### Octtree intersection ---def decide (lod: LOD) (ray:Ray) (voxel: Voxel) : (Fin 8)=>Bool = - -- Returned voxels are of the next level of detail. --- if intersectsVoxel ray voxel then --- (True, True, True, True, True, True, True, True) --- else - - -'## Octtree intersection - -octantBackToFrontTable : Octant=>Octant=>Octant = - intTable = [[ 0, 1, 2, 4, 3, 5, 6, 7 ], - [ 1, 0, 3, 5, 2, 4, 7, 6 ], - [ 2, 0, 3, 6, 1, 4, 7, 5 ], - [ 3, 1, 2, 7, 0, 5, 6, 4 ], - [ 4, 0, 5, 6, 1, 2, 7, 3 ], - [ 5, 1, 4, 7, 0, 3, 6, 2 ], - [ 6, 2, 4, 7, 0, 3, 5, 1 ], - [ 7, 3, 5, 6, 1, 2, 4, 0 ]] - for i j. (intTable.i.j@Octant) +def onSurface (lod: Int) (voxel: Voxel) : Bool = + centerPos = voxelToCentrePosition (lod@_) voxel + d = dilloMLP centerPos + w = voxelWidth lod + abs d < (sqrt(3.) * w /2.) -def childOctantToSubVoxel (lod: LoD) (voxel: Voxel) (oct:Octant) : Voxel = +def childOctantToSubVoxel (ray: Ray) (lod: Int) (oct: Octant) (voxel: Voxel) : (Voxel & Bool) = oo = IToW8 $ ordinal oct (x, y, z) = voxel - (2 * x + (BToI $ firstbit oo), - 2 * y + (BToI $ secondbit oo), - 2 * z + (BToI $ thirdbit oo)) - -def subVoxelPosToChildOctant (lod: LoD) (voxel: Voxel) (pos:Position) : Octant = - [mid_x, mid_y, mid_z] = voxelToCentrePosition lod voxel - [pos_x, pos_y, pos_z] = pos - bit1 = BToW8 $ pos_x > mid_x - bit2 = BToW8 $ pos_y > mid_y - bit3 = BToW8 $ pos_z > mid_z - bits = W8ToI $ bit1 .|. (bit2 << 1) .|. (bit3 << 2) - bits@Octant - -def orderedChildren (lod: LoD) (ray:Ray) (voxel: Voxel) : Octant=>Octant = - -- todo: return only those children that intersect - (pos, dir) = ray + vwidth = voxelWidth (lod + 1) + subvoxel = (x + vwidth*(BToF $ firstbit oo), + y + vwidth*(BToF $ secondbit oo), + z + vwidth*(BToF $ thirdbit oo)) + (subvoxel, intersectsVoxel ray (lod + 1) subvoxel) + +-- only called for subdividing the last LOD +def childOctantToLastSubVoxel (ray: Ray) (lod: Int) (oct: Octant) (voxel: Voxel) : (Voxel & Bool) = + oo = IToW8 $ ordinal oct + (x, y, z) = voxel + vwidth = voxelWidth (lod + 1) + subvoxelPos = (x + vwidth*(BToF $ firstbit oo), + y + vwidth*(BToF $ secondbit oo), + z + vwidth*(BToF $ thirdbit oo)) + intersects = intersectsVoxel ray (lod + 1) subvoxelPos + case intersects of + False -> (subvoxelPos, False) + True -> + case (onSurface lod voxel) of + False -> (subvoxelPos, False) + True -> (subvoxelPos, True) + +def intersectedVoxels (ray: Ray) (lod: LoD) (childOcts: Octant => Octant) (voxel: Voxel): List Voxel = + emptyList = AsList _ [] + yieldState emptyList \ref. + boundedIter 8 (get ref) \i. + subvoxel_pair = case (lod == lastLoD) of + False -> childOctantToSubVoxel ray (ordinal lod) childOcts.(i@_) voxel + True -> childOctantToLastSubVoxel ray (ordinal lod) childOcts.(i@_) voxel + if (snd subvoxel_pair) then + ref := concat [ (get ref), toList [(fst subvoxel_pair)] ] + Continue + +def orderedChildren (lod: LoD) (ray: Ray) (voxel: Voxel) : Octant=>Octant = + pos = fst $ intersection ray (ordinal lod) voxel oct = subVoxelPosToChildOctant lod voxel pos octantBackToFrontTable oct -def subdivide (ray:Ray) (lod: LoD) ((AsList _ voxels) : List Voxel) : List Voxel = - -- Returned voxels are of the next level of detail. - -- Todo: Encode this in the types. +def subdivide (ray: Ray) (lod: LoD) (voxelList: List Voxel) : List Voxel = + (AsList _ voxels) = voxelList yieldAccum (ListMonoid Voxel) \list. - for t. - if intersectsVoxel lod ray voxels.t then - childOcts = orderedChildren lod ray voxels.t - list += AsList 8 for i:(Fin 8). childOctantToSubVoxel lod voxels.t childOcts.i - -def rayTraceOctTree (ray: Ray) : List Voxel = - -- Returns a depth-ordered list of voxels intersecting a ray. - top_voxel = (0, 0, 0) - update = subdivide ray - init = AsList 1 [top_voxel] - fold init update - - -rayvoxels = rayTraceOctTree ([0.1, 0.1, 0.1], [1.0, 1.0, 1.0]) - --- :p rayvoxels - - - ----- Tests - ---p = voxelToCentrePosition (2@LoD) (0, 2, 0) --- :p subVoxelPosToChildOctant (1@LoD) (0, 2, 0) p - -:p rayAABBIntersection ([0., 0., 0.], [0.1, 0.1, 0.1]) ([-0.1, -0.1, -0.1], [0.1, 0.1, 0.1]) - - - -'## Raytracer for debugging + for i. + childOcts = orderedChildren lod ray voxels.i + intersected = intersectedVoxels ray lod childOcts voxels.i + list += intersected - - - -' ### Generic Helper Functions -Some of these should probably go in prelude. - --- def Vec (n:Int) : Type = Fin n => Float -def Mat (n:Int) (m:Int) : Type = Fin n => Fin m => Float - -def relu (x:Float) : Float = max x 0.0 -def length (x: d=>Float) : Float = sqrt $ sum for i. sq x.i --- TODO: make a newtype for normal vectors -def normalize (x: d=>Float) : d=>Float = x / (length x) -def directionAndLength (x: d=>Float) : (d=>Float & Float) = - l = length x - (x / (length x), l) - -def randuniform (lower:Float) (upper:Float) (k:Key) : Float = - lower + (rand k) * (upper - lower) - -def sampleAveraged [VSpace a] (sample:Key -> a) (n:Int) (k:Key) : a = - yieldState zero \total. - for i:(Fin n). - total := get total + sample (ixkey k i) / IToF n - -def positiveProjection (x:n=>Float) (y:n=>Float) : Bool = dot x y > 0.0 - -' ### 3D Helper Functions - -def cross (a:Vec 3) (b:Vec 3) : Vec 3 = - [a1, a2, a3] = a - [b1, b2, b3] = b - [a2 * b3 - a3 * b2, a3 * b1 - a1 * b3, a1 * b2 - a2 * b1] - --- TODO: Use `data Color = Red | Green | Blue` and ADTs for index sets -data Image = - MkImage height:Int width:Int (Fin height => Fin width => Color) - -xHat : Vec 3 = [1., 0., 0.] -yHat : Vec 3 = [0., 1., 0.] -zHat : Vec 3 = [0., 0., 1.] - -Angle = Float -- angle in radians - -def rotateX (p:Vec 3) (angle:Angle) : Vec 3 = - c = cos angle - s = sin angle - [px, py, pz] = p - [px, c*py - s*pz, s*py + c*pz] - -def rotateY (p:Vec 3) (angle:Angle) : Vec 3 = - c = cos angle - s = sin angle - [px, py, pz] = p - [c*px + s*pz, py, - s*px+ c*pz] - -def rotateZ (p:Vec 3) (angle:Angle) : Vec 3 = - c = cos angle - s = sin angle - [px, py, pz] = p - [c*px - s*py, s*px+c*py, pz] - -def sampleCosineWeightedHemisphere (normal: Vec 3) (k:Key) : Vec 3 = - [k1, k2] = splitKey k - u1 = rand k1 - u2 = rand k2 - uu = normalize $ cross normal [0.0, 1.1, 1.1] - vv = cross uu normal - ra = sqrt u2 - rx = ra * cos (2.0 * pi * u1) - ry = ra * sin (2.0 * pi * u1) - rz = sqrt (1.0 - u2) - rr = (rx .* uu) + (ry .* vv) + (rz .* normal) - normalize rr - -' ### Raytracer - -BlockHalfWidths = Vec 3 -Radius = Float -Radiance = Color - -data ObjectGeom = - Wall Direction Distance - Block Position BlockHalfWidths Angle - Sphere Position Radius - -data Surface = - Matte Color - Mirror - -OrientedSurface = (Direction & Surface) - -data Object = - PassiveObject ObjectGeom Surface - -- position, half-width, intensity (assumed to point down) - Light Position Float Radiance - -Filter = Color - --- TODO: use a record --- num samples, num bounces, share seed? -Params = { numSamples : Int - & maxBounces : Int - & shareSeed : Bool } - --- TODO: use a list instead, once they work -data Scene n:Type = MkScene (n=>Object) - -def sampleReflection ((nor, surf):OrientedSurface) ((pos, dir):Ray) (k:Key) : Ray = - newDir = case surf of - Matte _ -> sampleCosineWeightedHemisphere nor k - -- TODO: surely there's some change-of-solid-angle correction we need to - -- consider when reflecting off a curved surface. - Mirror -> dir - (2.0 * dot dir nor) .* nor - (pos, newDir) - -def probReflection ((nor, surf):OrientedSurface) (_:Ray) ((_, outRayDir):Ray) : Float = - case surf of - Matte _ -> relu $ dot nor outRayDir - Mirror -> 0.0 -- TODO: this should be a delta function of some sort - -def applyFilter (filter:Filter) (radiance:Radiance) : Radiance = - for i. filter.i * radiance.i - -def surfaceFilter (filter:Filter) (surf:Surface) : Filter = - case surf of - Matte color -> for i. filter.i * color.i - Mirror -> filter - -def sdObject (pos:Position) (obj:Object) : Distance = - case obj of - PassiveObject geom _ -> case geom of - Wall nor d -> d + dot nor pos - Block blockPos halfWidths angle -> - pos' = rotateY (pos - blockPos) angle - length $ for i. max ((abs pos'.i) - halfWidths.i) 0.0 - Sphere spherePos r -> - pos' = pos - spherePos - max (length pos' - r) 0.0 - Light squarePos hw _ -> - pos' = pos - squarePos - halfWidths = [hw, 0.01, hw] - length $ for i. max ((abs pos'.i) - halfWidths.i) 0.0 - -def sdScene (scene:Scene n) (pos:Position) : (Object & Distance) = - (MkScene objs) = scene - (i, d) = minimumBy snd $ for i. (i, sdObject pos objs.i) - (objs.i, d) - -def calcNormal (obj:Object) (pos:Position) : Direction = - normalize (grad (flip sdObject obj) pos) - -data RayMarchResult = - -- incident ray, surface normal, surface properties - HitObj Ray OrientedSurface - HitLight Radiance - -- Could refine with failure reason (beyond horizon, failed to converge etc) - HitNothing - -def raymarch (scene:Scene n) (ray:Ray) : RayMarchResult = - maxIters = 100 - tol = 0.01 - startLength = 10.0 * tol -- trying to escape the current surface +def rayOctreeIntersection (ray: Ray) : List Voxel = + init = AsList 1 [top_voxel] + voxelList = fst $ runState init \ref. for lod:LoD. + vlevel = get ref + ref := subdivide ray lod vlevel + vlevel + voxelList.lastLoD + +'### Ray-tracing + +def factorToDist (factor: Float) (dir: Direction): Distance = + factor * (vecLength dir) + +def distToFactor (d: Distance) (dir: Direction): Float = + d / (vecLength dir) + +def factorToPos (factor: Float) (orig: Position) (dir: Direction) : Position = + orig + factor .* dir + +def posToFactor (p: Position) (orig: Position) (dir: Direction) : Float = + dir' = (p - orig) + case dir.(0@_) == 0.0 of + False -> dir'.(0@_) / dir.(0@_) + True -> + case dir.(1@_) == 0.0 of + False -> dir'.(1@_) / dir.(1@_) + True -> + case dir.(2@_) == 0.0 of + False -> dir'.(2@_) / dir.(2@_) + True -> 1. + +def sphereTrace (ray: Ray) (lod: LoD) (voxelList: List Voxel) : (Position&Bool) = + maxIters = 200 + minDist = 0.0003 (rayOrigin, rayDir) = ray - withState (10.0 * tol) \rayLength. - boundedIter maxIters HitNothing \_. - rayPos = rayOrigin + get rayLength .* rayDir - (obj, d) = sdScene scene $ rayPos - -- 0.9 ensures we come close to the surface but don't touch it - rayLength := get rayLength + 0.9 * d - case d < tol of + (AsList numVoxels vl) = voxelList + final = yieldState (0.001, False) \t. + boundedIter numVoxels (0.0, False) \i. + voxPos = vl.(i@_) + (nrayVoxPos, frayVoxPos) = intersection ray (ordinal lod) voxPos -- returns (near, far) + t := ((posToFactor nrayVoxPos rayOrigin rayDir), snd (get t)) + + -- sphere tracing iterations + t := boundedIter maxIters (0.0, False) \_. + newDist = dilloMLP (factorToPos (fst (get t)) rayOrigin rayDir) + currDist = factorToDist (fst (get t)) rayDir + accumDist = currDist + newDist + t' = distToFactor accumDist rayDir + t := (t', (snd (get t))) + voxSpan = vecLength (frayVoxPos - nrayVoxPos) -- the maximum before the ray passes through voxel + case newDist > minDist of + True -> + case newDist < voxSpan of + True -> Continue + False -> Done ((fst (get t)), False) + False -> Done ((fst (get t)), True) + + case (snd (get t)) == True of + True -> Done (get t) False -> Continue - True -> - surfNorm = calcNormal obj rayPos - case positiveProjection rayDir surfNorm of - True -> - -- Oops, we didn't escape the surface we're leaving.. - -- (Is there a more standard way to do this?) - Continue - False -> - -- We made it! - Done $ case obj of - PassiveObject _ surf -> HitObj (rayPos, rayDir) (surfNorm, surf) - Light _ _ radiance -> HitLight radiance - -def rayDirectRadiance (scene:Scene n) (ray:Ray) : Radiance = - case raymarch scene ray of - HitLight intensity -> intensity - HitNothing -> zero - HitObj _ _ -> zero - -def sampleSquare (hw:Float) (k:Key) : Position = - [kx, kz] = splitKey k - x = randuniform (- hw) hw kx - z = randuniform (- hw) hw kz - [x, 0.0, z] - -def sampleLightRadiance - (scene:Scene n) (osurf:OrientedSurface) (inRay:Ray) (k:Key) : Radiance = - (surfNor, surf) = osurf - (rayPos, _) = inRay - (MkScene objs) = scene - yieldAccum (AddMonoid Float) \radiance. - for i. case objs.i of - PassiveObject _ _ -> () - Light lightPos hw _ -> - (dirToLight, distToLight) = directionAndLength $ - lightPos + sampleSquare hw k - rayPos - if positiveProjection dirToLight surfNor then - -- light on this far side of current surface - fracSolidAngle = (relu $ dot dirToLight yHat) * sq hw / (pi * sq distToLight) - outRay = (rayPos, dirToLight) - coeff = fracSolidAngle * probReflection osurf inRay outRay - radiance += coeff .* rayDirectRadiance scene outRay - -def trace (params:Params) (scene:Scene n) (initRay:Ray) (k:Key) : Color = - noFilter = [1.0, 1.0, 1.0] - yieldAccum (AddMonoid Float) \radiance. - runState noFilter \filter. - runState initRay \ray. - boundedIter (getAt #maxBounces params) () \i. - case raymarch scene $ get ray of - HitNothing -> Done () - HitLight intensity -> - if i == 0 then radiance += intensity -- TODO: scale etc - Done () - HitObj incidentRay osurf -> - [k1, k2] = splitKey $ hash k i - lightRadiance = sampleLightRadiance scene osurf incidentRay k1 - ray := sampleReflection osurf incidentRay k2 - filter := surfaceFilter (get filter) (snd osurf) - radiance += applyFilter (get filter) lightRadiance - Continue - --- Assumes we're looking towards -z. -Camera = - { numPix : Int - & pos : Position -- pinhole position - & halfWidth : Float -- sensor half-width - & sensorDist : Float } -- pinhole-sensor distance - --- TODO: might be better with an anonymous dependent pair for the result -def cameraRays (n:Int) (camera:Camera) : Fin n => Fin n => (Key -> Ray) = - -- images indexed from top-left - halfWidth = getAt #halfWidth camera - pixHalfWidth = halfWidth / IToF n - ys = reverse $ linspace (Fin n) (neg halfWidth) halfWidth - xs = linspace (Fin n) (neg halfWidth) halfWidth - for i j. \key. - [kx, ky] = splitKey key - x = xs.j + randuniform (-pixHalfWidth) pixHalfWidth kx - y = ys.i + randuniform (-pixHalfWidth) pixHalfWidth ky - (getAt #pos camera, normalize [x, y, neg (getAt #sensorDist camera)]) - -def takePicture (params:Params) (scene:Scene m) (camera:Camera) : Image = - n = getAt #numPix camera - rays = cameraRays n camera - rootKey = newKey 0 - image = for i j. - pixKey = if getAt #shareSeed params - then rootKey - else ixkey (ixkey rootKey i) j - sampleRayColor : Key -> Color = \k. - [k1, k2] = splitKey k - trace params scene (rays.i.j k1) k2 - sampleAveraged sampleRayColor (getAt #numSamples params) pixKey - MkImage _ _ $ image / mean (for (i,j,k). image.i.j.k) - -' ### Define the scene and render it - -lightColor = [0.2, 0.2, 0.2] -leftWallColor = 1.5 .* [0.611, 0.0555, 0.062] -rightWallColor = 1.5 .* [0.117, 0.4125, 0.115] -whiteWallColor = [255.0, 239.0, 196.0] / 255.0 -blockColor = [200.0, 200.0, 255.0] / 255.0 - -objs1 = [ Light (1.9 .* yHat) 0.5 lightColor - , PassiveObject (Wall xHat 2.0) (Matte whiteWallColor) - , PassiveObject (Wall (neg xHat) 2.0) (Matte whiteWallColor) - , PassiveObject (Wall yHat 2.0) (Matte whiteWallColor) - , PassiveObject (Wall (neg yHat) 2.0) (Matte whiteWallColor) - , PassiveObject (Wall zHat 2.0) (Matte whiteWallColor) - -- , PassiveObject (Block [ 1.0, -1.6, 1.2] [0.6, 0.8, 0.6] 0.5) (Matte blockColor) - -- , PassiveObject (Sphere [-1.0, -1.2, 0.2] 0.8) (Matte (0.7.* whiteWallColor)) - -- , PassiveObject (Sphere [ 2.0, 2.0, -2.0] 1.5) (Mirror) - ] - -(AsList _ vl) = rayvoxels - -shrink = 0.45 -objs2 = for i. - cpos = shrink .* (voxelToCentrePosition last vl.i) - vw = shrink * (0.5 * voxelWidth last) - color = hslToRgb (ixFraction i) 1.0 0.5 - PassiveObject (Block cpos [vw, vw, vw] 0.0) (Matte color) - -combl = (AsList _ objs1) <> (AsList _ objs2) -(AsList _ cot) = combl -theScene = MkScene cot - -defaultParams = { numSamples = 1 - , maxBounces = 2 - , shareSeed = True } - -defaultCamera = { numPix = 400 - , pos = 10.0 .* zHat - , halfWidth = 0.3 - , sensorDist = 1.0 } - --- We change to a small num pix here to reduce the compute needed for tests -params = defaultParams -camera = if dex_test_mode () - then defaultCamera |> setAt #numPix 10 - else defaultCamera + ((rayOrigin + (fst final) .* rayDir), snd final) -%time -(MkImage _ _ image) = takePicture params theScene camera -:html imshow image +'### Ray-octree setup +%time +rayvoxels: Height => Width => (Ray & List Voxel) = for x:Height. for y:Width. + rx = 2. * (IToF (ordinal x)) / IToF (size Height) - 1.0 + ry = 2. * (IToF (ordinal y)) / IToF (size Width) - 1.0 + ray = ([rx, ry, -1.], [0., 0., 1.]) + (ray, rayOctreeIntersection ray) +> Compile time: 94.326 s +> Run time: 4242.358 s --- :html imseqshow xmovieflat --- pngsToSavedGif 1 (map imgToPng xmovieflat) "gwg" +%time +finalPos: Height => Width => (Position & Bool) = for x:Height. for y:Width. + sphereTrace (fst rayvoxels.x.y) lastLoD (snd rayvoxels.x.y) +> Compile time: 138.776 s +> Run time: 1530.132 s +'### Test results +%time +canvas: Height => Width => Color = for x:Height. for y:Width. + (pos, onSurface) = finalPos.x.y + c = case onSurface of + True -> finitediff pos + False -> [0., 0., 0.] + -1. .* c +> Compile time: 72.465 s +> Run time: 441.041 s + +:html imshow canvas +>