diff --git a/node-graph/nodes/raster/src/filter.rs b/node-graph/nodes/raster/src/filter.rs index 14b6d45bf3..eacbd9fb1c 100644 --- a/node-graph/nodes/raster/src/filter.rs +++ b/node-graph/nodes/raster/src/filter.rs @@ -6,7 +6,7 @@ use raster_types::Image; use raster_types::{Bitmap, BitmapMut}; use raster_types::{CPU, Raster}; -/// Blurs the image with a Gaussian or blur kernel filter. +/// Blurs the image with a Gaussian or box blur kernel filter. #[node_macro::node(category("Raster: Filter"))] async fn blur( _: impl Ctx, @@ -42,6 +42,36 @@ async fn blur( .collect() } +/// Applies a median filter to reduce noise while preserving edges. +#[node_macro::node(category("Raster: Filter"))] +async fn median_filter( + _: impl Ctx, + /// The image to be filtered. + image_frame: Table>, + /// The radius of the filter kernel. Larger values remove more noise but may blur fine details. + #[range((0., 50.))] + #[hard_min(0.)] + radius: PixelLength, +) -> Table> { + image_frame + .into_iter() + .map(|mut row| { + let image = row.element.clone(); + + // Apply median filter + let filtered_image = if radius < 0.5 { + // Minimum filter radius + image.clone() + } else { + Raster::new_cpu(median_filter_algorithm(image.into_data(), radius as u32)) + }; + + row.element = filtered_image; + row + }) + .collect() +} + // 1D gaussian kernel fn gaussian_kernel(radius: f64) -> Vec { // Given radius, compute the size of the kernel that's approximately three times the radius @@ -179,3 +209,56 @@ fn box_blur_algorithm(mut original_buffer: Image, radius: f64, gamma: boo y_axis } + +fn median_filter_algorithm(original_buffer: Image, radius: u32) -> Image { + let (width, height) = original_buffer.dimensions(); + let mut output = Image::new(width, height, Color::TRANSPARENT); + + // Pre-allocate and reuse buffers outside the loops to avoid repeated allocations. + let window_capacity = ((2 * radius + 1).pow(2)) as usize; + let mut r_vals: Vec = Vec::with_capacity(window_capacity); + let mut g_vals: Vec = Vec::with_capacity(window_capacity); + let mut b_vals: Vec = Vec::with_capacity(window_capacity); + let mut a_vals: Vec = Vec::with_capacity(window_capacity); + + for y in 0..height { + for x in 0..width { + r_vals.clear(); + g_vals.clear(); + b_vals.clear(); + a_vals.clear(); + + // Use saturating_add to avoid potential overflow in extreme cases + let y_max = y.saturating_add(radius).min(height - 1); + let x_max = x.saturating_add(radius).min(width - 1); + + for ny in y.saturating_sub(radius)..=y_max { + for nx in x.saturating_sub(radius)..=x_max { + if let Some(px) = original_buffer.get_pixel(nx, ny) { + r_vals.push(px.r()); + g_vals.push(px.g()); + b_vals.push(px.b()); + a_vals.push(px.a()); + } + } + } + + let r = median_quickselect(&mut r_vals); + let g = median_quickselect(&mut g_vals); + let b = median_quickselect(&mut b_vals); + let a = median_quickselect(&mut a_vals); + + output.set_pixel(x, y, Color::from_rgbaf32_unchecked(r, g, b, a)); + } + } + + output +} +/// Finds the median of a slice using quickselect for O(n) average case performance. +/// This is more efficient than sorting the entire slice which would be O(n log n). +fn median_quickselect(values: &mut [f32]) -> f32 { + let mid: usize = values.len() / 2; + // nth_unstable is like quickselect: average O(n) + // Use total_cmp for safe NaN handling instead of partial_cmp().unwrap() + *values.select_nth_unstable_by(mid, |a, b| a.total_cmp(b)).1 +}