
Commit d7351b3

[compiler] Refactor BlockMatrix sparsity representations in type and lowering (#15163)
## Change Description

The bulk of the changes are in the files `BlockMatrixType.scala`, `MatrixSparsity.scala`, and `LowerBlockMatrixIR.scala`. I suggest starting with those, using the summaries below. I acknowledge this is a big change containing a lot of non-trivial logic. Please do ask for explanations and clarifications.

### BlockMatrixType and MatrixSparsity

`BlockMatrixType` has a `sparsity` field of type `BlockMatrixSparsity`, which is either dense or a set of present blocks. In this PR I've moved the sparsity representation to `MatrixSparsity`, a generic encoding of the sparsity pattern of a sparse matrix. For a `BlockMatrixType`, the sparsity is an `nBlockRows` by `nBlockCols` `MatrixSparsity`.

Besides moving `BlockMatrixSparsity` to `MatrixSparsity` and removing references to blocks, the encoding of the sparsity is also streamlined. Before, `BlockMatrixSparsity` held an array of coordinates of present blocks in an arbitrary order. Methods that needed the present blocks in a particular order (usually column major) had to sort them, and we also built a `Set` of present blocks to answer the `isPresent` query. Now, `MatrixSparsity` (in the sparse case) enforces that the array of present blocks is always in column-major order. This simplifies some of the logic and lets us answer `isPresent` with binary search (similarly for unions and intersections of sparsities, which now take advantage of the ordering). In addition, `MatrixSparsity` now knows its dimensions, which `BlockMatrixSparsity` did not.

I've also gotten rid of all the complicated methods on the CSC (compressed sparse column) encoding (e.g. `transposeCSCSparsity`). Instead, I do all transformations on the simple coordinate-list encoding, and convert to CSC as late as possible.

### LowerBlockMatrixIR

The changes in this file are mostly a redesign of the `BMSContexts` class.
`BMSContexts` bundles together a runtime array of context values (similar to the contexts for a `TableStage`) with a representation of the sparsity pattern, which conceptually maps each context value to the coordinates of its block. Before, the sparsity pattern was also encoded by runtime values; in the sparse case, this used runtime arrays `rowPos` and `rowIdx` in a CSC encoding. Working with this encoding created a lot of non-trivial runtime logic. But the sparsity is completely known at compile time, so all of this runtime logic was unnecessary.

There are still cases where we need to work with a CSC encoding of the sparsity at runtime. These are now handled by the class `DynamicBMSContexts`, which is essentially the old `BMSContexts` but with a minimal interface of three methods. Importantly, these methods are only about consuming a block matrix, not transforming it into a new one (like transpose).

The new `BMSContexts` pairs a statically known `MatrixSparsity` with a `DynamicBMSContexts`. The big change from before is that all methods producing a new `BMSContexts` now handle all the sparsity logic at compile time, and simply embed the new sparsity in the IR as literals. This typically also requires embedding an array of ints mapping old to new positions in the contexts array, which we use at runtime to reorder the contexts as needed.

## Security Assessment

- This change cannot impact the Hail Batch instance as deployed by Broad Institute in GCP.
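The column-major invariant and binary-search `isPresent` described above can be sketched in miniature. This is a hedged Python toy model, not Hail's actual `MatrixSparsity` API; all names and the union-by-merge strategy are illustrative:

```python
import bisect

class MatrixSparsity:
    """Toy model of a sparsity pattern whose present (row, col) block
    coordinates are always kept in column-major order, so membership is a
    binary search and set operations are linear merges of sorted lists."""

    def __init__(self, n_block_rows, n_block_cols, coords=None):
        self.n_block_rows = n_block_rows
        self.n_block_cols = n_block_cols
        # coords is None for a dense matrix; otherwise enforce column-major order.
        if coords is not None:
            coords = sorted(coords, key=lambda rc: (rc[1], rc[0]))
        self.coords = coords

    def _key(self, row, col):
        # Column-major linear index of a block coordinate.
        return col * self.n_block_rows + row

    def is_present(self, row, col):
        if self.coords is None:  # dense: every block is present
            return True
        keys = [self._key(r, c) for (r, c) in self.coords]
        i = bisect.bisect_left(keys, self._key(row, col))
        return i < len(keys) and keys[i] == self._key(row, col)

    def union(self, other):
        # Dense absorbs everything; otherwise merge the two sorted lists.
        if self.coords is None or other.coords is None:
            return MatrixSparsity(self.n_block_rows, self.n_block_cols)
        out, i, j = [], 0, 0
        a, b = self.coords, other.coords
        while i < len(a) and j < len(b):
            ka, kb = self._key(*a[i]), self._key(*b[j])
            if ka <= kb:
                out.append(a[i])
                i += 1
                j += 1 if ka == kb else 0
            else:
                out.append(b[j])
                j += 1
        out.extend(a[i:])
        out.extend(b[j:])
        return MatrixSparsity(self.n_block_rows, self.n_block_cols, out)
```

Because both operands of `union` are already sorted, the merge is linear in the number of present blocks, with no sorting or auxiliary `Set` needed.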
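The "sparsity logic at compile time, index gather at runtime" split can be illustrated with transpose. Again a hypothetical Python model, not the actual `BMSContexts` code: the lowering computes the new column-major sparsity and a new-to-old position map (the int array embedded in the IR as a literal), leaving only an array gather for runtime:

```python
def transpose_reorder(n_block_rows, n_block_cols, coords):
    """Compile-time side of a transpose: `coords` are the present blocks of
    an (n_block_rows x n_block_cols) matrix in column-major order. Returns
    the transposed sparsity, also column-major, plus the new-to-old position
    map. Names are illustrative, not Hail's."""
    # Swap coordinates, remembering each block's old position.
    transposed = [((c, r), old_pos) for old_pos, (r, c) in enumerate(coords)]
    # Re-establish column-major order in the transposed matrix.
    transposed.sort(key=lambda entry: (entry[0][1], entry[0][0]))
    new_coords = [rc for rc, _ in transposed]
    new_to_old = [old_pos for _, old_pos in transposed]
    return new_coords, new_to_old


def reorder_contexts(contexts, new_to_old):
    # Runtime side: all that remains is an index gather over the contexts.
    return [contexts[i] for i in new_to_old]
```

Everything except the final gather happens on plain compile-time data, which is the point of the redesign described above.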
1 parent 0e29386 commit d7351b3

File tree

17 files changed: +1150 −958 lines changed

hail/build.mill

Lines changed: 5 additions & 1 deletion
```diff
@@ -113,6 +113,10 @@ trait HailModule extends ScalaModule with ScalafmtModule with ScalafixModule { o
     mvn"io.github.tanin47::scalafix-forbidden-symbol:1.0.0",
   )
 
+  override def depManagement: T[Seq[Dep]] = Task {
+    Seq(Deps.collection_compat)
+  }
+
   override def javacOptions: T[Seq[String]] = Seq(
     "-Xlint:all",
     "-Werror",
@@ -350,7 +354,7 @@ trait RootHailModule extends CrossScalaModule with HailModule { outer =>
   override def defaultTask(): String = "generate"
 
   override def mvnDeps = Seq(
-    mvn"com.lihaoyi::mainargs:0.6.2",
+    mvn"com.lihaoyi::mainargs:0.7.7",
    mvn"com.lihaoyi::os-lib:0.10.7",
    mvn"com.lihaoyi::sourcecode:0.4.2",
  )
```

hail/hail/src/is/hail/expr/ir/BlockMatrixIR.scala

Lines changed: 81 additions & 82 deletions
Large diffs are not rendered by default.

hail/hail/src/is/hail/expr/ir/BlockMatrixWriter.scala

Lines changed: 6 additions & 3 deletions
```diff
@@ -8,7 +8,7 @@ import is.hail.expr.ir.defs.{MetadataWriter, Str, UUID4, WriteMetadata, WriteVal
 import is.hail.expr.ir.lowering.{BlockMatrixStage2, LowererUnsupportedOperation}
 import is.hail.io.{StreamBufferSpec, TypedCodecSpec}
 import is.hail.io.fs.FS
-import is.hail.linalg.{BlockMatrix, BlockMatrixMetadata}
+import is.hail.linalg.{BlockMatrix, BlockMatrixMetadata, MatrixSparsity}
 import is.hail.types.TypeWithRequiredness
 import is.hail.types.encoded.{EBlockMatrixNDArray, ENumpyBinaryNDArray, EType}
 import is.hail.types.virtual._
@@ -135,8 +135,11 @@ case class BlockMatrixNativeMetadataWriter(
     cb: EmitCodeBuilder,
     region: Value[Region],
   ): Unit = {
-    val metaHelper =
-      BMMetadataHelper(path, typ.blockSize, typ.nRows, typ.nCols, typ.linearizedDefinedBlocks)
+    val partIdxToBlockIdx = typ.sparsity match {
+      case _: MatrixSparsity.Dense => None
+      case x: MatrixSparsity.Sparse => Some(x.definedBlocksColMajorLinear)
+    }
+    val metaHelper = BMMetadataHelper(path, typ.blockSize, typ.nRows, typ.nCols, partIdxToBlockIdx)
 
     val pc = writeAnnotations.getOrFatal(cb, "write annotations can't be missing!").asIndexable
     val partFiles = cb.newLocal[Array[String]]("partFiles")
```

hail/hail/src/is/hail/expr/ir/IR.scala

Lines changed: 2 additions & 0 deletions
```diff
@@ -512,6 +512,8 @@ package defs {
     def <=(other: IR): IR = ApplyComparisonOp(LTEQ, self, other)
 
     def >=(other: IR): IR = ApplyComparisonOp(GTEQ, self, other)
+
+    def log(messages: AnyRef*): IR = logIR(self, messages: _*)
   }
 
   object ErrorIDs {
```

hail/hail/src/is/hail/expr/ir/MatrixWriter.scala

Lines changed: 4 additions & 4 deletions
```diff
@@ -14,7 +14,7 @@ import is.hail.io.gen.{BgenWriter, ExportGen}
 import is.hail.io.index.StagedIndexWriter
 import is.hail.io.plink.{BitPacker, ExportPlink}
 import is.hail.io.vcf.{ExportVCF, TabixVCF}
-import is.hail.linalg.BlockMatrix
+import is.hail.linalg.{BlockMatrix, MatrixSparsity}
 import is.hail.rvd.{IndexSpec, RVDPartitioner, RVDSpecMaker}
 import is.hail.types._
 import is.hail.types.encoded.{EBaseStruct, EBlockMatrixNDArray, EType}
@@ -2340,7 +2340,7 @@ case class MatrixBlockMatrixWriter(
 
     val countColumnsIR = ArrayLen(GetField(ts.getGlobals(), colsFieldName))
     val numCols: Int = CompileAndEvaluate[Int](ctx, countColumnsIR)
-    val numBlockCols: Int = (numCols - 1) / blockSize + 1
+    val numBlockCols: Int = BlockMatrixType.numBlocks(numCols.toLong, blockSize)
     val lastBlockNumCols = (numCols - 1) % blockSize + 1
 
     val rowCountIR = ts.mapCollect("matrix_block_matrix_writer_partition_counts")(paritionIR =>
@@ -2353,7 +2353,7 @@ case class MatrixBlockMatrixWriter(
     val inputPartStops = inputPartStartsPlusLast.tail
 
     val numRows = inputPartStartsPlusLast.last
-    val numBlockRows: Int = (numRows.toInt - 1) / blockSize + 1
+    val numBlockRows: Int = BlockMatrixType.numBlocks(numRows, blockSize)
 
     // Zip contexts with partition starts and ends
     val zippedWithStarts = ts.mapContexts { oldContextsStream =>
@@ -2510,7 +2510,7 @@ case class MatrixBlockMatrixWriter(
       numRows,
       numCols.toLong,
       blockSize,
-      BlockMatrixSparsity.dense,
+      MatrixSparsity.dense(numBlockRows, numBlockCols),
     )
     RelationalWriter.scoped(path, overwrite, None)(WriteMetadata(
       flatPaths,
```
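The `BlockMatrixType.numBlocks` helper that replaces the repeated inline `(n - 1) / blockSize + 1` pattern in this diff is, presumably, a ceiling division. A minimal model (Python, under that assumption), showing the two forms agree for positive sizes:

```python
def num_blocks(n, block_size):
    """Ceiling division: how many blocks of `block_size` are needed to cover
    `n` entries. Models what the diff's BlockMatrixType.numBlocks computes
    in place of the inline (n - 1) / blockSize + 1 pattern."""
    assert block_size > 0
    return (n + block_size - 1) // block_size
```

For example, 10 columns with a block size of 4 need 3 block columns, the last one only partially filled.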

hail/hail/src/is/hail/expr/ir/functions/ArrayFunctions.scala

Lines changed: 48 additions & 0 deletions
```diff
@@ -311,6 +311,54 @@ object ArrayFunctions extends RegistryFunctions {
     ToArray(flatMapIR(ToStream(a))(ToStream(_)))
   }
 
+  /* Construct an array of length `len`, with the values in `elts` copied to the positions in
+   * `indices` */
+  registerSCode3t(
+    "scatter",
+    Array(tv("T")),
+    TArray(tv("T")), // elts
+    TArray(TInt32), // indices
+    TInt32, // len
+    TArray(tv("T")),
+    (_, a, _, _) => PCanonicalArray(a.asInstanceOf[SContainer].elementType.storageType()).sType,
+  ) {
+    case (
+          er,
+          cb,
+          _,
+          rt: SIndexablePointer,
+          elts: SIndexableValue,
+          indices: SIndexableValue,
+          len: SInt32Value,
+          errorID,
+        ) =>
+      cb.if_(
+        elts.loadLength.cne(indices.loadLength),
+        cb._fatalWithError(errorID, "scatter: values and indices arrays have different lengths"),
+      )
+      cb.if_(
+        elts.loadLength > len.value,
+        cb._fatalWithError(errorID, "scatter: values array is larger than result length"),
+      )
+      val pt = rt.pType.asInstanceOf[PCanonicalArray]
+      val (push, finish) =
+        pt.constructFromIndicesUnsafe(cb, er.region, len.value, deepCopy = false)
+      indices.forEachDefined(cb) { case (cb, pos, idx: SInt32Value) =>
+        cb.if_(
+          idx.value < 0 || idx.value >= len.value,
+          cb._fatalWithError(
+            errorID,
+            "scatter: indices array contained index ",
+            idx.value.toS,
+            ", which is greater than result length ",
+            len.value.toS,
+          ),
+        )
+        push(cb, idx.value, elts.loadElement(cb, pos))
+      }
+      finish(cb)
+  }
+
   registerSCode4(
     "lowerBound",
     TArray(tv("T")),
```
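The semantics of the new `scatter` function can be modeled in plain Python. This is a sketch of the behavior, not the staged implementation: `None` stands in for slots no index writes to (how the staged code represents those is not shown in the diff), and the error checks mirror the staged ones:

```python
def scatter(elts, indices, length):
    """Model of `scatter`: build an array of `length` slots with elts[k]
    written to position indices[k]."""
    if len(elts) != len(indices):
        raise ValueError("scatter: values and indices arrays have different lengths")
    if len(elts) > length:
        raise ValueError("scatter: values array is larger than result length")
    out = [None] * length
    for value, idx in zip(elts, indices):
        if idx < 0 or idx >= length:
            raise ValueError(f"scatter: index {idx} out of bounds for result length {length}")
        out[idx] = value
    return out
```

This is the inverse of a gather, and is presumably what the lowering uses to apply the old-to-new position arrays described in the commit message.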

hail/hail/src/is/hail/expr/ir/functions/Functions.scala

Lines changed: 25 additions & 0 deletions
```diff
@@ -815,6 +815,31 @@ abstract class RegistryFunctions {
     case (r, cb, _, rt, Array(a1, a2, a3), errorID) => impl(r, cb, rt, a1, a2, a3, errorID)
   }
 
+  def registerSCode3t(
+    name: String,
+    typeParams: Array[Type],
+    mt1: Type,
+    mt2: Type,
+    mt3: Type,
+    rt: Type,
+    pt: (Type, SType, SType, SType) => SType,
+  )(
+    impl: (
+      EmitRegion,
+      EmitCodeBuilder,
+      Seq[Type],
+      SType,
+      SValue,
+      SValue,
+      SValue,
+      Value[Int],
+    ) => SValue
+  ): Unit =
+    registerSCode(name, Array(mt1, mt2, mt3), rt, unwrappedApply(pt), typeParams) {
+      case (r, cb, typeParams, rt, Array(a1, a2, a3), errorID) =>
+        impl(r, cb, typeParams, rt, a1, a2, a3, errorID)
+    }
+
   def registerSCode4(
     name: String,
     mt1: Type,
```

0 commit comments
