
Commit 9cb2e22

joeycarter and rmoyard authored
[MLIR] Add type checking in JVPOp::verifySymbolUses (#1020)
**Context:** Add type checking in the JVP MLIR operation like we have in the equivalent VJP operation.

**Description of the Change:** In `JVPOp::verifySymbolUses()`, gather up the data types of the tangent parameters and check them one by one against the data types of the corresponding callee input types, in a similar manner as is currently done in `VJPOp::verifySymbolUses()`.

**Benefits:** The goal is to avoid triggering [this assert](https://github.com/PennyLaneAI/catalyst/blob/6c0ed0b528119b78bc32172780350ff1bc760424/mlir/lib/Gradient/Utils/EinsumLinalgGeneric.cpp#L100) by doing the type checking earlier and printing a more descriptive error message.

**Possible Drawbacks:** None.

**Related GitHub Issues:** [sc-48792]

---------

Co-authored-by: Romain Moyard <[email protected]>
1 parent e1e12fe commit 9cb2e22
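
The check described in the commit message boils down to a pairwise comparison of each tangent operand's type against the corresponding callee input type. The committed implementation appears in the `GradientOps.cpp` diff below; the helper here is only a distilled, standalone sketch of that comparison, and the function name `findFirstTypeMismatch` is hypothetical rather than part of the change.

```cpp
// Illustrative sketch only, not part of this commit: return the index of the
// first tangent whose type differs from the corresponding callee input type,
// or std::nullopt if every checked pair matches. The caller is expected to
// have already validated that there are no more tangents than callee inputs.
#include <cstddef>
#include <optional>

#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Types.h"

static std::optional<size_t> findFirstTypeMismatch(llvm::ArrayRef<mlir::Type> calleeInputTypes,
                                                   llvm::ArrayRef<mlir::Type> tangentTypes)
{
    for (size_t i = 0; i < tangentTypes.size(); ++i) {
        // mlir::Type is a uniqued value type, so != is an exact (and cheap)
        // type identity comparison.
        if (calleeInputTypes[i] != tangentTypes[i]) {
            return i;
        }
    }
    return std::nullopt;
}
```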

File tree

- doc/changelog.md
- mlir/lib/Gradient/IR/GradientOps.cpp
- mlir/test/Gradient/VerifierTest.mlir

3 files changed: +92 -0 lines changed

doc/changelog.md

Lines changed: 6 additions & 0 deletions
@@ -290,6 +290,11 @@
   are no longer decomposed when using Catalyst, improving compilation & runtime performance.
   [(#955)](https://github.com/PennyLaneAI/catalyst/pull/955)
 
+* Improve error messaging for `catalyst.jvp` when the callee input type and the tangent
+  type are not compatible by performing type-checking at the MLIR level. Note that the
+  equivalent type checking is already performed in `catalyst.vjp`.
+  [(#1020)](https://github.com/PennyLaneAI/catalyst/pull/1020)
+
 <h3>Breaking changes</h3>
 
 * Return values of qjit-compiled functions that were previously `numpy.ndarray` are now of type
@@ -468,6 +473,7 @@
 
 This release contains contributions from (in alphabetical order):
 
+Joey Carter,
 Alessandro Cosentino,
 Lillian M. A. Frederiksen,
 Josh Izaac,

mlir/lib/Gradient/IR/GradientOps.cpp

Lines changed: 22 additions & 0 deletions
@@ -317,6 +317,28 @@ LogicalResult JVPOp::verifySymbolUses(SymbolTableCollection &symbolTable)
         }
     }
 
+    std::vector<Type> tanTypes;
+    {
+        auto tanOperands = OperandRange(
+            this->operand_begin() + callee.getFunctionType().getNumInputs(), this->operand_end());
+        for (auto c : tanOperands) {
+            tanTypes.push_back(c.getType());
+        }
+    }
+
+    auto calleeInputTypes = callee.getFunctionType().getInputs();
+
+    // Check that callee inputs have the same types as tangent inputs
+    for (size_t i = 0; i < tanTypes.size(); i++) {
+        auto tanType = tanTypes[i];
+        auto cIType = calleeInputTypes[i];
+        if (tanType != cIType) {
+            return this->emitOpError("callee input type does not match the tangent type")
+                   << " callee input " << i << " was expected to be of type " << tanType
+                   << " but got " << cIType;
+        }
+    }
+
     return success();
 }
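
For comparison, the same pairwise check can be written with the LLVM ADT range helpers that ship with MLIR (`llvm::zip` and `llvm::enumerate` from `llvm/ADT/STLExtras.h`), avoiding the temporary `std::vector<Type>` and the manual indexing; `llvm::zip` stops at the shorter range, which matches the original loop bound of `tanTypes.size()`. This is only a sketch of an alternative formulation, not the committed code, and it reuses the accessor expressions from the diff above.

```cpp
// Sketch only: an equivalent way to write the check inside
// JVPOp::verifySymbolUses(), iterating tangent operands and callee input
// types in lockstep instead of building a temporary vector of types.
#include "llvm/ADT/STLExtras.h" // llvm::enumerate, llvm::zip

auto tanOperands = OperandRange(
    this->operand_begin() + callee.getFunctionType().getNumInputs(), this->operand_end());
auto calleeInputTypes = callee.getFunctionType().getInputs();

for (const auto &it : llvm::enumerate(llvm::zip(tanOperands, calleeInputTypes))) {
    auto [tanOperand, calleeInputType] = it.value();
    if (tanOperand.getType() != calleeInputType) {
        return this->emitOpError("callee input type does not match the tangent type")
               << " callee input " << it.index() << " was expected to be of type "
               << tanOperand.getType() << " but got " << calleeInputType;
    }
}
```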

mlir/test/Gradient/VerifierTest.mlir

Lines changed: 64 additions & 0 deletions
@@ -268,6 +268,37 @@ gradient.jvp "fd" @measure(%cst0) tangents(%cst1) : (tensor<2xf64>, tensor<2xf64
 
 // -----
 
+func.func @foo(%arg0: tensor<f64>) -> (tensor<f64>, tensor<f64>) {
+
+    %cst = stablehlo.constant dense<2.000000e+00> : tensor<f64>
+    %0 = stablehlo.multiply %cst, %arg0 : tensor<f64>
+    %1 = stablehlo.multiply %arg0, %arg0 : tensor<f64>
+    return %0, %1 : tensor<f64>, tensor<f64>
+
+}
+
+%cst0 = arith.constant dense<1.0> : tensor<f64>
+%cst1 = arith.constant dense<1.0> : tensor<f64>
+gradient.jvp "auto" @foo(%cst0) tangents(%cst1) : (tensor<f64>, tensor<f64>) -> (tensor<f64>, tensor<f64>, tensor<f64>, tensor<f64>)
+
+// -----
+
+func.func @foo(%arg0: tensor<f64>) -> (tensor<f64>, tensor<f64>) {
+
+    %cst = stablehlo.constant dense<2.000000e+00> : tensor<f64>
+    %0 = stablehlo.multiply %cst, %arg0 : tensor<f64>
+    %1 = stablehlo.multiply %arg0, %arg0 : tensor<f64>
+    return %0, %1 : tensor<f64>, tensor<f64>
+
+}
+
+%cst0 = arith.constant dense<1.0> : tensor<f64>
+%cst1 = arith.constant dense<1> : tensor<i64>
+// expected-error@+1 {{callee input type does not match the tangent type}}
+gradient.jvp "auto" @foo(%cst0) tangents(%cst1) : (tensor<f64>, tensor<i64>) -> (tensor<f64>, tensor<f64>, tensor<f64>, tensor<f64>)
+
+// -----
+
 func.func @measure(%arg0: tensor<2xf64>) -> tensor<2xf64> {
 
     %c0 = arith.constant 0 : i64
@@ -303,6 +334,39 @@ gradient.vjp "fd" @measure(%cst0) cotangents(%cst1) {resultSegmentSizes = array<
 
 // -----
 
+func.func @foo(%arg0: tensor<f64>) -> (tensor<f64>, tensor<f64>) {
+
+    %cst = stablehlo.constant dense<2.000000e+00> : tensor<f64>
+    %0 = stablehlo.multiply %cst, %arg0 : tensor<f64>
+    %1 = stablehlo.multiply %arg0, %arg0 : tensor<f64>
+    return %0, %1 : tensor<f64>, tensor<f64>
+
+}
+
+%cst0 = arith.constant dense<1.0> : tensor<f64>
+%cst1 = arith.constant dense<1.0> : tensor<f64>
+%cst2 = arith.constant dense<1.0> : tensor<f64>
+gradient.vjp "auto" @foo(%cst0) cotangents(%cst1, %cst2) {resultSegmentSizes = array<i32: 2, 1>} : (tensor<f64>, tensor<f64>, tensor<f64>) -> (tensor<f64>, tensor<f64>, tensor<f64>)
+
+// -----
+
+func.func @foo(%arg0: tensor<f64>) -> (tensor<f64>, tensor<f64>) {
+
+    %cst = stablehlo.constant dense<2.000000e+00> : tensor<f64>
+    %0 = stablehlo.multiply %cst, %arg0 : tensor<f64>
+    %1 = stablehlo.multiply %arg0, %arg0 : tensor<f64>
+    return %0, %1 : tensor<f64>, tensor<f64>
+
+}
+
+%cst0 = arith.constant dense<1.0> : tensor<f64>
+%cst1 = arith.constant dense<1> : tensor<i64>
+%cst2 = arith.constant dense<1> : tensor<i64>
+// expected-error@+1 {{callee result type does not match the cotangent type}}
+gradient.vjp "auto" @foo(%cst0) cotangents(%cst1, %cst2) {resultSegmentSizes = array<i32: 2, 1>} : (tensor<f64>, tensor<i64>, tensor<i64>) -> (tensor<f64>, tensor<f64>, tensor<f64>)
+
+// -----
+
 module @grad.wrapper {
   func.func public @jit_grad.wrapper(%arg0: tensor<2xf64>) -> tensor<2xf64> attributes {llvm.emit_c_interface} {
     %0 = gradient.grad "auto" @wrapper(%arg0) {diffArgIndices = dense<0> : tensor<1xi64>} : (tensor<2xf64>) -> tensor<2xf64>
