@@ -326,43 +326,41 @@ else
326
326
_tail (nt:: NamedTuple ) = Base. tail (nt)
327
327
end
328
328
329
- function subset (varinfo:: UntypedVarInfo , vns:: AbstractVector{<:VarName} )
329
+ function subset (varinfo:: VarInfo , vns:: AbstractVector{<:VarName} )
330
330
metadata = subset (varinfo. metadata, vns)
331
331
return VarInfo (metadata, deepcopy (varinfo. logp), deepcopy (varinfo. num_produce))
332
332
end
333
333
334
- function subset (varinfo:: VectorVarInfo , vns:: AbstractVector{<:VarName} )
335
- metadata = subset (varinfo. metadata, vns)
336
- return VarInfo (metadata, deepcopy (varinfo. logp), deepcopy (varinfo. num_produce))
334
+ function subset (metadata:: NamedTuple , vns:: AbstractVector{<:VarName} )
335
+ vns_syms = Set (unique (map (getsym, vns)))
336
+ syms = filter (Base. Fix2 (in, vns_syms), keys (metadata))
337
+ metadatas = map (syms) do sym
338
+ subset (getfield (metadata, sym), filter (== (sym) ∘ getsym, vns))
339
+ end
340
+ return NamedTuple {syms} (metadatas)
337
341
end
338
342
339
- function subset (varinfo:: TypedVarInfo , vns:: AbstractVector{<:VarName{sym}} ) where {sym}
340
- # If all the variables are using the same symbol, then we can just extract that field from the metadata.
341
- metadata = subset (getfield (varinfo. metadata, sym), vns)
342
- return VarInfo (
343
- NamedTuple {(sym,)} (tuple (metadata)),
344
- deepcopy (varinfo. logp),
345
- deepcopy (varinfo. num_produce),
346
- )
347
- end
343
+ # The above method is type unstable since we don't know which symbols are in `vns`.
344
+ # In the below special case, when all `vns` have the same symbol, we can write a type stable
345
+ # version.
348
346
349
- function subset (varinfo:: TypedVarInfo , vns:: AbstractVector{<:VarName} )
350
- syms = Tuple (unique (map (getsym, vns)))
351
- metadatas = map (syms) do sym
352
- subset (getfield (varinfo. metadata, sym), filter (== (sym) ∘ getsym, vns))
347
+ @generated function subset (
348
+ metadata:: NamedTuple{names} , vns:: AbstractVector{<:VarName{sym}}
349
+ ) where {names,sym}
350
+ return if (sym in names)
351
+ # TODO (mhauru) Note that this could still generate an empty metadata object if none
352
+ # of the lenses in `vns` are in `metadata`. Not sure if that's okay. Checking for
353
+ # emptiness would make this type unstable again.
354
+ :((; $ sym= subset (metadata.$ sym, vns)))
355
+ else
356
+ :(NamedTuple {} ())
353
357
end
354
-
355
- return VarInfo (
356
- NamedTuple {syms} (metadatas), deepcopy (varinfo. logp), deepcopy (varinfo. num_produce)
357
- )
358
358
end
359
359
360
360
function subset (metadata:: Metadata , vns_given:: AbstractVector{VN} ) where {VN<: VarName }
361
361
# TODO : Should we error if `vns` contains a variable that is not in `metadata`?
362
- # For each `vn` in `vns`, get the variables subsumed by `vn`.
363
- vns = mapreduce (vcat, vns_given; init= VN[]) do vn
364
- filter (Base. Fix1 (subsumes, vn), metadata. vns)
365
- end
362
+ # Find all the vns in metadata that are subsumed by one of the given vns.
363
+ vns = filter (vn -> any (subsumes (vn_given, vn) for vn_given in vns_given), metadata. vns)
366
364
indices_for_vns = map (Base. Fix1 (getindex, metadata. idcs), vns)
367
365
indices = if isempty (vns)
368
366
Dict {VarName,Int} ()
0 commit comments