diff --git a/examples/lodtracer.dx b/examples/lodtracer.dx new file mode 100644 index 000000000..fd208add7 --- /dev/null +++ b/examples/lodtracer.dx @@ -0,0 +1,505 @@ +'# Neural implicit rendering +Adapted from [NGLOD](https://nv-tlabs.github.io/nglod/) by Towaki Takikawa et al., +which encodes geometry implicitly in a small neural network. +Since NN computations are expensive, we enhance the basic +sphere-tracing algorithm with an octree ADT to cut them down. + +import plot + +'### Types + +Height = Fin 600 +Width = Fin 600 + +def Vec (n: Int) : Type = Fin n => Float +Position = Vec 3 +Direction = Vec 3 +Distance = Float +Ray = (Position & Direction) +Voxel = (Float & Float & Float) +LoD = Fin 5 +Octant = Fin 8 +def AABox : Type = (Position & Position) -- opposite corners + +top_voxel = (-1., -1., -1.) +top_size = 2. + +fdim = 32 +FDim = Fin fdim + +fsize = 4 + +Input = Fin (fdim + 3) +HiddenLayer = Fin 128 +InputWeights = HiddenLayer => Input => Float +OutputWeights = Fin 1 => HiddenLayer => Float +Biases = HiddenLayer => Float +OutputBiases = Fin 1 => Float + +'### General Utilies + +def W8ToB' (x : Word8) : Bool = x > (IToW8 0) -- Bug in prelude + +def firstbit (x:Word8) : Bool = W8ToB' $ x .&. (IToW8 1) +def secondbit (x:Word8) : Bool = W8ToB' $ x .&. (IToW8 2) +def thirdbit (x:Word8) : Bool = W8ToB' $ x .&. (IToW8 4) + +def lastLoD : LoD = ((size LoD) - 1)@LoD + +def vecLength (v: Vec 3) : Distance = + [x, y, z] = v + sqrt $ sum for i. sq v.i + +'### Loading NN parameters + +def unsafeSlice (xs: n => a) (start:Int) (m: Type) : m => a = + for i. xs.(unsafeFromOrdinal _ (ordinal i + start)) + +def mkFeatureVol (pts: Fin b => c) (m: Type) : FDim => m => m => m => c = + vol = for d2:(FDim). + for d3:m. + for d4:m. + for d5:m. + pts.((ordinal (d2, d3, d4, d5))@_) + vol + +def mkWeights (x: m=>v) : n=>o=>v = + for i j. x.(ordinal (i,j)@_) + +def mkBiases (x: m=>v) : n=>v = + for i. x.(ordinal i@_) + +def loadFloats (f: FilePath): List Float32 = + (AsList byteCount bytesArray) = unsafeIO do readFile f + floatCount = idiv byteCount 4 + floatArray: Fin floatCount => Float32 = unsafeIO do + withTabPtr bytesArray \ptr. + tabFromPtr (Fin floatCount) $ castPtr ptr + AsList floatCount floatArray + +def getIndex (power:Int) : Int = + yieldAccum (AddMonoid Int) \sum. + for i:(Fin power). + r = intpow2 (ordinal i) + fsizei = fsize * r + flatsize = fdim * (fsizei+1) * (fsizei+1) * (fsizei+1) + sum += flatsize + +def getSize (power:Int) : Int = fsize * (intpow2 power) + 1 + +loaded_features = loadFloats "examples/features.bin" +(AsList numf features) = loaded_features + +def query_feature_vol (lod: Int) (m: Type): FDim => m => m => m => Float = + level = lod + j = getIndex level + flatsizei = fdim * (size m) * (size m) * (size m) + featuresi = unsafeSlice features j (Fin flatsizei) + feature_vol = mkFeatureVol featuresi m + feature_vol + +lod0 = getSize 0 +lod1 = getSize 1 +lod2 = getSize 2 +lod3 = getSize 3 +lod4 = getSize 4 + +f0 = query_feature_vol 0 (Fin lod0) +f1 = query_feature_vol 1 (Fin lod1) +f2 = query_feature_vol 2 (Fin lod2) +f3 = query_feature_vol 3 (Fin lod3) +f4 = query_feature_vol 4 (Fin lod4) + +loaded_params = loadFloats "examples/weights.bin" +(AsList nump params) = loaded_params + +sInputWeights = (size HiddenLayer) * (size Input) +sBiases = size HiddenLayer +sOutputWeights = size HiddenLayer +sOutputBiases = 1 +numParam = (sInputWeights) + (sBiases) + (sOutputWeights) + (sOutputBiases) + +def query_input_weights (lod: Int) : InputWeights = + j = lod * numParam + weights = unsafeSlice params j (Fin sInputWeights) + iw: InputWeights = mkWeights weights + iw + +def query_biases (lod: Int) : Biases = + offset = sInputWeights + j = lod * numParam + offset + biases = unsafeSlice params j (Fin sBiases) + b: Biases = mkBiases biases + b + +def query_output_weights (lod: Int) : OutputWeights = + offset = sInputWeights + sBiases + j = lod * numParam + offset + weights = unsafeSlice params j (Fin sOutputWeights) + iw: OutputWeights = mkWeights weights + iw + +def query_output_biases (lod: Int) : OutputBiases = + offset = sInputWeights + sBiases + sOutputWeights + j = lod * numParam + offset + biases = unsafeSlice params j (Fin sOutputBiases) + b: OutputBiases = mkBiases biases + b + +w1 = query_input_weights (ordinal lastLoD) +b1 = query_biases (ordinal lastLoD) +w2 = query_output_weights (ordinal lastLoD) +b2 = query_output_biases (ordinal lastLoD) + +'### Interpolation + +def safe_get (input: FDim => (Fin a) => (Fin a) => (Fin a) => Float) + (x: Float) (y: Float) (z: Float) + (fsize: Float) (dim: Int): Float = + case (x >= 0. && x < fsize && y >= 0. && y < fsize && z >= 0. && z < fsize) of + True -> input.(dim@_).((FToI z)@_).((FToI y)@_).((FToI x)@_) + False -> 0. + +def clip_coordinates (input: Float) (clip_limit: Float) : Float = + min (clip_limit - 1.) (max input 0.) + +-- implements torch.nn.Functional.grid_sample from https://github.com/pytorch/pytorch/blob/f064c5aa33483061a48994608d890b968ae53fb5/aten/src/THNN/generic/VolumetricGridSamplerBilinear.c +def trilinear_interp (feature_vol: FDim => (Fin a) => (Fin a) => (Fin a) => Float) + (p: Position) : FDim => Float = + fsize = IToF a + -- normalize position from [-1, 1] to [0, fsize-1] + x = ((p.(0@_) + 1.) / 2.) * (fsize - 1.) + y = ((p.(1@_) + 1.) / 2.) * (fsize - 1.) + z = ((p.(2@_) + 1.) / 2.) * (fsize - 1.) + -- corner values from (x, y, z) : north-east-south-west-top-bottom + ix_tnw = floor x + iy_tnw = floor y + iz_tnw = floor z + ix_tne = ix_tnw + 1.0 + iy_tne = iy_tnw + iz_tne = iz_tnw + ix_tsw = ix_tnw + iy_tsw = iy_tnw + 1.0 + iz_tsw = iz_tnw + ix_tse = ix_tnw + 1.0 + iy_tse = iy_tnw + 1.0 + iz_tse = iz_tnw + ix_bnw = ix_tnw + iy_bnw = iy_tnw + iz_bnw = iz_tnw + 1.0 + ix_bne = ix_tnw + 1.0 + iy_bne = iy_tnw + iz_bne = iz_tnw + 1.0 + ix_bsw = ix_tnw + iy_bsw = iy_tnw + 1.0 + iz_bsw = iz_tnw + 1.0 + ix_bse = ix_tnw + 1.0 + iy_bse = iy_tnw + 1.0 + iz_bse = iz_tnw + 1.0 + -- weighted corner values from (x, y, z) + tnw = (ix_bse - x) * (iy_bse - y) * (iz_bse - z) + tne = (x - ix_bsw) * (iy_bsw - y) * (iz_bsw - z) + tsw = (ix_bne - x) * (y - iy_bne) * (iz_bne - z) + tse = (x - ix_bnw) * (y - iy_bnw) * (iz_bnw - z) + bnw = (ix_tse - x) * (iy_tse - y) * (z - iz_tse) + bne = (x - ix_tsw) * (iy_tsw - y) * (z - iz_tsw) + bsw = (ix_tne - x) * (y - iy_tne) * (z - iz_tne) + bse = (x - ix_tnw) * (y - iy_tnw) * (z - iz_tnw) + -- pad out-of-bound grid locations with border values + x_tnw = clip_coordinates ix_tnw fsize + y_tnw = clip_coordinates iy_tnw fsize + z_tnw = clip_coordinates iz_tnw fsize + x_tne = clip_coordinates ix_tne fsize + y_tne = clip_coordinates iy_tne fsize + z_tne = clip_coordinates iz_tne fsize + x_tsw = clip_coordinates ix_tsw fsize + y_tsw = clip_coordinates iy_tsw fsize + z_tsw = clip_coordinates iz_tsw fsize + x_tse = clip_coordinates ix_tse fsize + y_tse = clip_coordinates iy_tse fsize + z_tse = clip_coordinates iz_tse fsize + x_bnw = clip_coordinates ix_bnw fsize + y_bnw = clip_coordinates iy_bnw fsize + z_bnw = clip_coordinates iz_bnw fsize + x_bne = clip_coordinates ix_bne fsize + y_bne = clip_coordinates iy_bne fsize + z_bne = clip_coordinates iz_bne fsize + x_bsw = clip_coordinates ix_bsw fsize + y_bsw = clip_coordinates iy_bsw fsize + z_bsw = clip_coordinates iz_bsw fsize + x_bse = clip_coordinates ix_bse fsize + y_bse = clip_coordinates iy_bse fsize + z_bse = clip_coordinates iz_bse fsize + for i:(FDim). + tnw_val = safe_get feature_vol x_tnw y_tnw z_tnw fsize (ordinal i) + tne_val = safe_get feature_vol x_tne y_tne z_tne fsize (ordinal i) + tsw_val = safe_get feature_vol x_tsw y_tsw z_tsw fsize (ordinal i) + tse_val = safe_get feature_vol x_tse y_tse z_tse fsize (ordinal i) + bnw_val = safe_get feature_vol x_bnw y_bnw z_bnw fsize (ordinal i) + bne_val = safe_get feature_vol x_bne y_bne z_bne fsize (ordinal i) + bsw_val = safe_get feature_vol x_bsw y_bsw z_bsw fsize (ordinal i) + bse_val = safe_get feature_vol x_bse y_bse z_bse fsize (ordinal i) + tnw_val * tnw + tne_val * tne + tsw_val * tsw + tse_val * tse + bnw_val * bnw + bne_val * bne + bsw_val * bsw + bse_val * bse + +'### Neural network + +@noinline +def relu (input : Float) : Float = + select (input > 0.0) input 0.0 + +def dilloMLP (p: Position) : Distance = + f = (trilinear_interp f0 p) + (trilinear_interp f1 p) + (trilinear_interp f2 p) + (trilinear_interp f3 p) + (trilinear_interp f4 p) + comb_input = (AsList _ p) <> (AsList _ f) + (AsList _ cinput) = comb_input + input = for i:Input. cinput.(unsafeFromOrdinal _ (ordinal i)) + l1 = w1 **. input + b1 + l2 = for i. relu l1.i + l3 = w2 **. l2 + b2 + l3.(0@_) + +def finitediff (x: Position) : Direction = + min_dist = 1.0/(64.0 * 3.0) + eps_x = [min_dist, 0.0, 0.0] + eps_y = [0.0, min_dist, 0.0] + eps_z = [0.0, 0.0, min_dist] + inputs = [(x + eps_x), (x - eps_x), (x + eps_y), (x - eps_y), (x + eps_z), (x - eps_z)] + lst = yieldAccum (ListMonoid Float) \list. + for i:(Fin 6). + list += AsList _ [(dilloMLP inputs.i) / (min_dist*2.0)] + (AsList _ g) = lst + [g.(0@_) - g.(1@_), g.(2@_) - g.(3@_), g.(4@_) - g.(5@_)] + +'### Voxel utilities + +def voxelWidth (lod: Int) : Float = + -- voxels divide up cube in (-1, 1) + numDivisions = intpow2 $ lod + top_size / (IToF numDivisions) + +def voxelToCentrePosition (lod: LoD) (voxel: Voxel) : Position = + divSize = voxelWidth (ordinal lod) + (xn, yn, zn) = voxel + [0.5*divSize + xn, + 0.5*divSize + yn, + 0.5*divSize + zn] + +octantBackToFrontTable : Octant => Octant => Octant = + intTable = [[ 0, 1, 2, 4, 3, 5, 6, 7 ], + [ 1, 0, 3, 5, 2, 4, 7, 6 ], + [ 2, 0, 3, 6, 1, 4, 7, 5 ], + [ 3, 1, 2, 7, 0, 5, 6, 4 ], + [ 4, 0, 5, 6, 1, 2, 7, 3 ], + [ 5, 1, 4, 7, 0, 3, 6, 2 ], + [ 6, 2, 4, 7, 0, 3, 5, 1 ], + [ 7, 3, 5, 6, 1, 2, 4, 0 ]] + for i j. (intTable.i.j@Octant) + +def subVoxelPosToChildOctant (lod: LoD) (voxel: Voxel) (pos:Position) : Octant = + [mid_x, mid_y, mid_z] = voxelToCentrePosition lod voxel + [pos_x, pos_y, pos_z] = pos + bit1 = BToW8 $ pos_x > mid_x + bit2 = BToW8 $ pos_y > mid_y + bit3 = BToW8 $ pos_z > mid_z + bits = W8ToI $ bit1 .|. (bit2 .<<. 1) .|. (bit3 .<<. 2) + bits@Octant + +'### Bounding box utilities + +def voxeltoBB (lod:Int) (voxel: Voxel) : AABox = + vwidth = voxelWidth lod + (x, y, z) = voxel + lower = [x, y, z] + upper = [x + vwidth, + y + vwidth, + z + vwidth] + (lower, upper) + +def rayAABBIntersects (ray:Ray) (box:AABox) : Bool = + ([pos_x, pos_y, pos_z], [dir_x, dir_y, dir_z ]) = ray + ([low_x, low_y, low_z], [high_x, high_y, high_z]) = box + (tx1, tx2) = ((low_x - pos_x) / dir_x, (high_x - pos_x) / dir_x) + (ty1, ty2) = ((low_y - pos_y) / dir_y, (high_y - pos_y) / dir_y) + (tz1, tz2) = ((low_z - pos_z) / dir_z, (high_z - pos_z) / dir_z) + txn = min tx1 tx2 + txf = max tx1 tx2 + tyn = min ty1 ty2 + tyf = max ty1 ty2 + tzn = min tz1 tz2 + tzf = max tz1 tz2 + tnear = max txn $ max tyn tzn + tfar = min txf $ min tyf tzf + tfar > tnear && tfar > 0.0 + +def rayAABBIntersection (ray: Ray) (box: AABox) : (Position & Position) = + ([pos_x, pos_y, pos_z], [dir_x, dir_y, dir_z ]) = ray + ([low_x, low_y, low_z], [high_x, high_y, high_z]) = box + (tx1, tx2) = case dir_x == 0.0 of + True -> (low_x, high_x) + False -> ((low_x - pos_x) / dir_x, (high_x - pos_x) / dir_x) + (ty1, ty2) = case dir_y == 0.0 of + True -> (low_y, high_y) + False -> ((low_y - pos_y) / dir_y, (high_y - pos_y) / dir_y) + (tz1, tz2) = case dir_z == 0.0 of + True -> (low_z, high_z) + False -> ((low_z - pos_z) / dir_z, (high_z - pos_z) / dir_z) + ([pos_x + tx1 * dir_x, pos_y + ty1 * dir_y, pos_z + tz1 * dir_z], + [pos_x + tx2 * dir_x, pos_y + ty2 * dir_y, pos_z + tz2 * dir_z]) + +def intersectsVoxel (ray: Ray) (lod: Int) (voxel: Voxel) : Bool = + box = voxeltoBB lod voxel + rayAABBIntersects ray box + +def intersection (ray: Ray) (lod: Int) (voxel: Voxel) : (Position & Position) = + box = voxeltoBB lod voxel + rayAABBIntersection ray box + +'### Octtree intersection + +def onSurface (lod: Int) (voxel: Voxel) : Bool = + centerPos = voxelToCentrePosition (lod@_) voxel + d = dilloMLP centerPos + w = voxelWidth lod + abs d < (sqrt(3.) * w /2.) + +def childOctantToSubVoxel (ray: Ray) (lod: Int) (oct: Octant) (voxel: Voxel) : (Voxel & Bool) = + oo = IToW8 $ ordinal oct + (x, y, z) = voxel + vwidth = voxelWidth (lod + 1) + subvoxel = (x + vwidth*(BToF $ firstbit oo), + y + vwidth*(BToF $ secondbit oo), + z + vwidth*(BToF $ thirdbit oo)) + (subvoxel, intersectsVoxel ray (lod + 1) subvoxel) + +-- only called for subdividing the last LOD +def childOctantToLastSubVoxel (ray: Ray) (lod: Int) (oct: Octant) (voxel: Voxel) : (Voxel & Bool) = + oo = IToW8 $ ordinal oct + (x, y, z) = voxel + vwidth = voxelWidth (lod + 1) + subvoxelPos = (x + vwidth*(BToF $ firstbit oo), + y + vwidth*(BToF $ secondbit oo), + z + vwidth*(BToF $ thirdbit oo)) + intersects = intersectsVoxel ray (lod + 1) subvoxelPos + case intersects of + False -> (subvoxelPos, False) + True -> + case (onSurface lod voxel) of + False -> (subvoxelPos, False) + True -> (subvoxelPos, True) + +def intersectedVoxels (ray: Ray) (lod: LoD) (childOcts: Octant => Octant) (voxel: Voxel): List Voxel = + emptyList = AsList _ [] + yieldState emptyList \ref. + boundedIter 8 (get ref) \i. + subvoxel_pair = case (lod == lastLoD) of + False -> childOctantToSubVoxel ray (ordinal lod) childOcts.(i@_) voxel + True -> childOctantToLastSubVoxel ray (ordinal lod) childOcts.(i@_) voxel + if (snd subvoxel_pair) then + ref := concat [ (get ref), toList [(fst subvoxel_pair)] ] + Continue + +def orderedChildren (lod: LoD) (ray: Ray) (voxel: Voxel) : Octant=>Octant = + pos = fst $ intersection ray (ordinal lod) voxel + oct = subVoxelPosToChildOctant lod voxel pos + octantBackToFrontTable oct + +def subdivide (ray: Ray) (lod: LoD) (voxelList: List Voxel) : List Voxel = + (AsList _ voxels) = voxelList + yieldAccum (ListMonoid Voxel) \list. + for i. + childOcts = orderedChildren lod ray voxels.i + intersected = intersectedVoxels ray lod childOcts voxels.i + list += intersected + +def rayOctreeIntersection (ray: Ray) : List Voxel = + init = AsList 1 [top_voxel] + voxelList = fst $ runState init \ref. for lod:LoD. + vlevel = get ref + ref := subdivide ray lod vlevel + vlevel + voxelList.lastLoD + +'### Ray-tracing + +def factorToDist (factor: Float) (dir: Direction): Distance = + factor * (vecLength dir) + +def distToFactor (d: Distance) (dir: Direction): Float = + d / (vecLength dir) + +def factorToPos (factor: Float) (orig: Position) (dir: Direction) : Position = + orig + factor .* dir + +def posToFactor (p: Position) (orig: Position) (dir: Direction) : Float = + dir' = (p - orig) + case dir.(0@_) == 0.0 of + False -> dir'.(0@_) / dir.(0@_) + True -> + case dir.(1@_) == 0.0 of + False -> dir'.(1@_) / dir.(1@_) + True -> + case dir.(2@_) == 0.0 of + False -> dir'.(2@_) / dir.(2@_) + True -> 1. + +def sphereTrace (ray: Ray) (lod: LoD) (voxelList: List Voxel) : (Position&Bool) = + maxIters = 200 + minDist = 0.0003 + (rayOrigin, rayDir) = ray + (AsList numVoxels vl) = voxelList + final = yieldState (0.001, False) \t. + boundedIter numVoxels (0.0, False) \i. + voxPos = vl.(i@_) + (nrayVoxPos, frayVoxPos) = intersection ray (ordinal lod) voxPos -- returns (near, far) + t := ((posToFactor nrayVoxPos rayOrigin rayDir), snd (get t)) + + -- sphere tracing iterations + t := boundedIter maxIters (0.0, False) \_. + newDist = dilloMLP (factorToPos (fst (get t)) rayOrigin rayDir) + currDist = factorToDist (fst (get t)) rayDir + accumDist = currDist + newDist + t' = distToFactor accumDist rayDir + t := (t', (snd (get t))) + voxSpan = vecLength (frayVoxPos - nrayVoxPos) -- the maximum before the ray passes through voxel + case newDist > minDist of + True -> + case newDist < voxSpan of + True -> Continue + False -> Done ((fst (get t)), False) + False -> Done ((fst (get t)), True) + + case (snd (get t)) == True of + True -> Done (get t) + False -> Continue + ((rayOrigin + (fst final) .* rayDir), snd final) + +'### Ray-octree setup + +%time +rayvoxels: Height => Width => (Ray & List Voxel) = for x:Height. for y:Width. + rx = 2. * (IToF (ordinal x)) / IToF (size Height) - 1.0 + ry = 2. * (IToF (ordinal y)) / IToF (size Width) - 1.0 + ray = ([rx, ry, -1.], [0., 0., 1.]) + (ray, rayOctreeIntersection ray) +> Compile time: 94.326 s +> Run time: 4242.358 s + +%time +finalPos: Height => Width => (Position & Bool) = for x:Height. for y:Width. + sphereTrace (fst rayvoxels.x.y) lastLoD (snd rayvoxels.x.y) +> Compile time: 138.776 s +> Run time: 1530.132 s + +'### Test results + +%time +canvas: Height => Width => Color = for x:Height. for y:Width. + (pos, onSurface) = finalPos.x.y + c = case onSurface of + True -> finitediff pos + False -> [0., 0., 0.] + -1. .* c +> Compile time: 72.465 s +> Run time: 441.041 s + +:html imshow canvas +>