Commit ed3aebb

adds ChromaticAberration and Vignette
1 parent 94bdf40 commit ed3aebb

4 files changed: +394 -0 lines changed


README.md

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,7 @@ Both images have the workflow attached, and it is included so feel free to use i
 - Blend: Blends two images together with a variety of different modes
 - Blur: Applies a Gaussian blur to the input image, softening the details
 - CannyEdgeDetection: Applies Canny edge detection to the input image
+- Chromatic Aberration: Shifts the color channels in an image, creating a glitch aesthetic
 - ColorCorrect: Adjusts the color balance, temperature, hue, brightness, contrast, saturation, and gamma of an image
 - Dissolve: Creates a grainy blend of two images using random pixels based on a dissolve factor.
 - DodgeAndBurn: Adjusts image brightness using dodge and burn effects based on a mask and intensity.
@@ -28,6 +29,7 @@ Both images have the workflow attached, and it is included so feel free to use i
 - Quantize: Set and dither the amount of colors in an image from 0-256, reducing color information
 - Sharpen: Enhances the details in an image by applying a sharpening filter
 - Solarize: Inverts image colors based on a threshold for a striking, high-contrast effect
+- Vignette: Applies a vignette effect, putting the corners of the image in shadow
 
 ## Combine Nodes

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
import torch

class ChromaticAberration:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "red_shift": ("INT", {
                    "default": 0,
                    "min": -20,
                    "max": 20,
                    "step": 1
                }),
                "red_direction": (["horizontal", "vertical"],),
                "green_shift": ("INT", {
                    "default": 0,
                    "min": -20,
                    "max": 20,
                    "step": 1
                }),
                "green_direction": (["horizontal", "vertical"],),
                "blue_shift": ("INT", {
                    "default": 0,
                    "min": -20,
                    "max": 20,
                    "step": 1
                }),
                "blue_direction": (["horizontal", "vertical"],),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "chromatic_aberration"

    CATEGORY = "postprocessing"

    def chromatic_aberration(self, image: torch.Tensor, red_shift: int, green_shift: int, blue_shift: int, red_direction: str, green_direction: str, blue_direction: str):
        def get_shift(direction, shift):
            shift = -shift if direction == 'vertical' else shift  # invert vertical shift as otherwise positive actually shifts down
            return (shift, 0) if direction == 'vertical' else (0, shift)

        x = image.permute(0, 3, 1, 2)
        shifts = [get_shift(direction, shift) for direction, shift in zip([red_direction, green_direction, blue_direction], [red_shift, green_shift, blue_shift])]
        channels = [torch.roll(x[:, i, :, :], shifts=shifts[i], dims=(1, 2)) for i in range(3)]

        output = torch.stack(channels, dim=1)
        output = output.permute(0, 2, 3, 1)

        return (output,)

NODE_CLASS_MAPPINGS = {
    "ChromaticAberration": ChromaticAberration
}
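
As a quick, hypothetical illustration (not part of the commit), the node's function can be called directly on a ComfyUI-style image tensor (batch x height x width x channel floats in [0, 1]); the random test image and the shift values below are placeholders.

import torch

# Standalone sketch: assumes the ChromaticAberration class above is in scope.
node = ChromaticAberration()
img = torch.rand(1, 256, 256, 3)  # placeholder B x H x W x C image
(shifted,) = node.chromatic_aberration(
    img,
    red_shift=5, green_shift=0, blue_shift=-5,
    red_direction="horizontal", green_direction="horizontal", blue_direction="vertical",
)
print(shifted.shape)  # torch.Size([1, 256, 256, 3])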

post_processing/vignette.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
import numpy as np
import torch

class Vignette:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "vignette": ("FLOAT", {  # key must match the apply_vignette parameter name
                    "default": 0.0,
                    "min": 0.0,
                    "max": 10.0,
                    "step": 1.0
                }),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "apply_vignette"

    CATEGORY = "postprocessing"

    def apply_vignette(self, image: torch.Tensor, vignette: float):
        if vignette == 0:
            return (image,)
        height, width, _ = image.shape[-3:]

        # Build a radial distance map with the same height x width as the image
        x = torch.linspace(-1, 1, width, device=image.device)
        y = torch.linspace(-1, 1, height, device=image.device)
        X, Y = torch.meshgrid(x, y, indexing="xy")  # "xy" yields an (height, width) grid so non-square images broadcast correctly
        radius = torch.sqrt(X ** 2 + Y ** 2)

        # Map vignette strength from 1-10 to a falloff radius of 1.8 down to 0.9
        mapped_vignette_strength = 1.8 - (vignette - 1) * 0.1
        vignette = 1 - torch.clamp(radius / mapped_vignette_strength, 0, 1)
        vignette = vignette[..., None]

        vignette_image = torch.clamp(image * vignette, 0, 1)

        return (vignette_image,)

NODE_CLASS_MAPPINGS = {
    "Vignette": Vignette,
}
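
Likewise, a minimal sketch of exercising the Vignette node outside of a ComfyUI graph, assuming the class above is in scope; the all-white test image and the strength value are arbitrary placeholders.

import torch

# Standalone sketch: assumes the Vignette class above is importable.
node = Vignette()
img = torch.ones(1, 240, 320, 3)  # placeholder all-white B x H x W x C image
(out,) = node.apply_vignette(img, vignette=5.0)
print(out[0, 120, 160, 0].item())  # centre pixel stays close to 1.0
print(out[0, 0, 0, 0].item())      # corner pixel is darkened (0.0 at this strength)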
