Commit ad5a062
committed
Make the jaxpr for jnp.pad in "constant" mode more succinct.
Example before:
```
$ print(jax.jit(lambda x: jnp.pad(x, ((0, 0), (1, 0), (0, 1)), constant_values=7)).lower(jnp.ones((3,4,5))).as_text())
module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<3x4x5xf32>) -> (tensor<3x5x6xf32> {jax.result_info = ""}) {
%c = stablehlo.constant dense<7> : tensor<i32>
%0 = call @_pad(%arg0, %c) : (tensor<3x4x5xf32>, tensor<i32>) -> tensor<3x5x6xf32>
return %0 : tensor<3x5x6xf32>
}
func.func private @_pad(%arg0: tensor<3x4x5xf32>, %arg1: tensor<i32>) -> tensor<3x5x6xf32> {
%0 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<i32>) -> tensor<3x2xi32>
%1 = stablehlo.convert %0 : (tensor<3x2xi32>) -> tensor<3x2xf32>
%2 = stablehlo.slice %1 [0:1, 0:1] : (tensor<3x2xf32>) -> tensor<1x1xf32>
%3 = stablehlo.reshape %2 : (tensor<1x1xf32>) -> tensor<f32>
%4 = stablehlo.pad %arg0, %3, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
%5 = stablehlo.slice %1 [0:1, 1:2] : (tensor<3x2xf32>) -> tensor<1x1xf32>
%6 = stablehlo.reshape %5 : (tensor<1x1xf32>) -> tensor<f32>
%7 = stablehlo.pad %4, %6, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
%8 = stablehlo.slice %1 [1:2, 0:1] : (tensor<3x2xf32>) -> tensor<1x1xf32>
%9 = stablehlo.reshape %8 : (tensor<1x1xf32>) -> tensor<f32>
%10 = stablehlo.pad %7, %9, low = [0, 1, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x5x5xf32>
%11 = stablehlo.slice %1 [1:2, 1:2] : (tensor<3x2xf32>) -> tensor<1x1xf32>
%12 = stablehlo.reshape %11 : (tensor<1x1xf32>) -> tensor<f32>
%13 = stablehlo.pad %10, %12, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x5x5xf32>, tensor<f32>) -> tensor<3x5x5xf32>
%14 = stablehlo.slice %1 [2:3, 0:1] : (tensor<3x2xf32>) -> tensor<1x1xf32>
%15 = stablehlo.reshape %14 : (tensor<1x1xf32>) -> tensor<f32>
%16 = stablehlo.pad %13, %15, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x5x5xf32>, tensor<f32>) -> tensor<3x5x5xf32>
%17 = stablehlo.slice %1 [2:3, 1:2] : (tensor<3x2xf32>) -> tensor<1x1xf32>
%18 = stablehlo.reshape %17 : (tensor<1x1xf32>) -> tensor<f32>
%19 = stablehlo.pad %16, %18, low = [0, 0, 0], high = [0, 0, 1], interior = [0, 0, 0] : (tensor<3x5x5xf32>, tensor<f32>) -> tensor<3x5x6xf32>
return %19 : tensor<3x5x6xf32>
}
}
```
After:
```
module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<3x4x5xf32>) -> (tensor<3x5x6xf32> {jax.result_info = ""}) {
%c = stablehlo.constant dense<7> : tensor<i32>
%0 = call @_pad(%arg0, %c) : (tensor<3x4x5xf32>, tensor<i32>) -> tensor<3x5x6xf32>
return %0 : tensor<3x5x6xf32>
}
func.func private @_pad(%arg0: tensor<3x4x5xf32>, %arg1: tensor<i32>) -> tensor<3x5x6xf32> {
%0 = stablehlo.convert %arg1 : (tensor<i32>) -> tensor<f32>
%1 = stablehlo.pad %arg0, %0, low = [0, 1, 0], high = [0, 0, 1], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x5x6xf32>
return %1 : tensor<3x5x6xf32>
}
}
```1 parent aefe621 commit ad5a062
1 file changed
+27
-5
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
4048 | 4048 | | |
4049 | 4049 | | |
4050 | 4050 | | |
4051 | | - | |
4052 | 4051 | | |
4053 | 4052 | | |
| 4053 | + | |
| 4054 | + | |
| 4055 | + | |
| 4056 | + | |
| 4057 | + | |
| 4058 | + | |
| 4059 | + | |
| 4060 | + | |
| 4061 | + | |
| 4062 | + | |
| 4063 | + | |
| 4064 | + | |
| 4065 | + | |
| 4066 | + | |
| 4067 | + | |
| 4068 | + | |
| 4069 | + | |
| 4070 | + | |
| 4071 | + | |
| 4072 | + | |
| 4073 | + | |
4054 | 4074 | | |
4055 | 4075 | | |
4056 | | - | |
4057 | | - | |
4058 | | - | |
4059 | | - | |
| 4076 | + | |
| 4077 | + | |
| 4078 | + | |
| 4079 | + | |
| 4080 | + | |
| 4081 | + | |
4060 | 4082 | | |
4061 | 4083 | | |
4062 | 4084 | | |
| |||
0 commit comments