Skip to content

Commit cfc2c0e

Browse files
dvdskosyvokon
andcommitted
Adds mode constant to Pad
We drop the check on the number of inputs. We could re-introduce it but it would need to be different for the supported modes. Co-authored-by: Oleksiy <[email protected]>
1 parent b11adf7 commit cfc2c0e

File tree

1 file changed

+28
-7
lines changed

1 file changed

+28
-7
lines changed

candle-onnx/src/eval.rs

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,13 +1146,6 @@ fn simple_eval_(
11461146
let mode = get_attr_opt(node, "mode")?.unwrap_or("constant");
11471147
let data = get(&node.input[0])?;
11481148
let pads = get(&node.input[1])?;
1149-
if node.input.len() > 2 {
1150-
bail!(
1151-
"unsupported number of inputs {} for Pad node {:?}, expected 2",
1152-
node.input.len(),
1153-
node.name
1154-
);
1155-
}
11561149
if pads.rank() != 1 {
11571150
bail!("Pad expects 'pads' input to be 1D vector: {pads:?}");
11581151
}
@@ -1189,6 +1182,34 @@ fn simple_eval_(
11891182

11901183
values.insert(node.output[0].clone(), out);
11911184
}
1185+
"constant" => {
1186+
let value = if node.input.len() > 2 {
1187+
get(&node.input[2])?.to_vec0::<f32>()?
1188+
} else {
1189+
0.0
1190+
};
1191+
1192+
let mut out = data.clone();
1193+
for (axis, (pad_pre, pad_post)) in
1194+
pads_pre.iter().zip(pads_post).enumerate()
1195+
{
1196+
if *pad_pre == 0 && *pad_post == 0 {
1197+
continue;
1198+
}
1199+
1200+
let mut new_dims = out.dims().to_vec();
1201+
new_dims[axis] += (*pad_pre + *pad_post) as usize;
1202+
1203+
out = Tensor::full(value, new_dims, out.device())?.slice_scatter(
1204+
&out,
1205+
axis,
1206+
*pad_pre as usize,
1207+
)?;
1208+
}
1209+
1210+
values.insert(node.output[0].clone(), out);
1211+
}
1212+
11921213
_ => bail!(
11931214
"unsupported 'mode' value {mode:?} for Pad node {:?}",
11941215
node.name

0 commit comments

Comments
 (0)