Skip to content

Commit eb7ce1c

Browse files
dvdskosyvokon
andcommitted
Adds mode constant to Pad
We drop the check on the number of inputs. We could re-introduce it but it would need to be different for the supported modes. Co-authored-by: Oleksiy <[email protected]>
1 parent 6a8d538 commit eb7ce1c

File tree

1 file changed

+28
-7
lines changed

1 file changed

+28
-7
lines changed

candle-onnx/src/eval.rs

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,13 +1150,6 @@ fn simple_eval_(
11501150
let mode = get_attr_opt(node, "mode")?.unwrap_or("constant");
11511151
let data = get(&node.input[0])?;
11521152
let pads = get(&node.input[1])?;
1153-
if node.input.len() > 2 {
1154-
bail!(
1155-
"unsupported number of inputs {} for Pad node {:?}, expected 2",
1156-
node.input.len(),
1157-
node.name
1158-
);
1159-
}
11601153
if pads.rank() != 1 {
11611154
bail!("Pad expects 'pads' input to be 1D vector: {pads:?}");
11621155
}
@@ -1193,6 +1186,34 @@ fn simple_eval_(
11931186

11941187
values.insert(node.output[0].clone(), out);
11951188
}
1189+
"constant" => {
1190+
let value = if node.input.len() > 2 {
1191+
get(&node.input[2])?.to_vec0::<f32>()?
1192+
} else {
1193+
0.0
1194+
};
1195+
1196+
let mut out = data.clone();
1197+
for (axis, (pad_pre, pad_post)) in
1198+
pads_pre.iter().zip(pads_post).enumerate()
1199+
{
1200+
if *pad_pre == 0 && *pad_post == 0 {
1201+
continue;
1202+
}
1203+
1204+
let mut new_dims = out.dims().to_vec();
1205+
new_dims[axis] += (*pad_pre + *pad_post) as usize;
1206+
1207+
out = Tensor::full(value, new_dims, out.device())?.slice_scatter(
1208+
&out,
1209+
axis,
1210+
*pad_pre as usize,
1211+
)?;
1212+
}
1213+
1214+
values.insert(node.output[0].clone(), out);
1215+
}
1216+
11961217
_ => bail!(
11971218
"unsupported 'mode' value {mode:?} for Pad node {:?}",
11981219
node.name

0 commit comments

Comments
 (0)