|
1 | 1 | import torch |
2 | | -from .helper_functions import getIndexTensorAlongDim, comonLazy, eval_tensor_expr, make_zero_like |
| 2 | +from .helper_functions import generate_dim_variables, getIndexTensorAlongDim, comonLazy, eval_tensor_expr, make_zero_like |
3 | 3 |
|
4 | 4 | from comfy_api.latest import io |
5 | 5 |
|
|
9 | 9 | class AudioMathNode(MathNodeBase): |
10 | 10 | """ |
11 | 11 | Enables math expressions on Audio tensors. |
12 | | - |
| 12 | +
|
13 | 13 | Inputs: |
14 | 14 | a, b, c, d: Audio inputs (b, c, d default to zero if not provided) |
15 | 15 | w, x, y, z: Float variables for expressions |
16 | 16 | AudioExpr: Expression to apply on audio tensors |
17 | | - |
| 17 | +
|
18 | 18 | Outputs: |
19 | 19 | AUDIO: Result of applying expression to input audio |
20 | 20 | """ |
@@ -42,31 +42,32 @@ def define_schema(cls) -> io.Schema: |
42 | 42 | ) |
43 | 43 |
|
44 | 44 | @classmethod |
45 | | - def execute(cls, a, AudioExpr, b=None, c=None, d=None, w=0.0, x=0.0, y=0.0, z=0.0): |
46 | | - waveform = a['waveform'] |
| 45 | + def execute(cls, AudioExpr, a, b=None, c=None, d=None, w=0.0, x=0.0, y=0.0, z=0.0): |
| 46 | + av = a['waveform'] |
47 | 47 | sample_rate = a['sample_rate'] |
48 | 48 |
|
49 | 49 | a, b, c, d = cls.prepare_inputs(a, b, c, d) |
50 | 50 |
|
51 | 51 | bv, cv, dv = b['waveform'], c['waveform'], d['waveform'] |
52 | 52 |
|
53 | 53 | variables = { |
54 | | - 'a': waveform, 'b': bv, 'c': cv, 'd': dv, |
| 54 | + 'a': av, 'b': bv, 'c': cv, 'd': dv, |
55 | 55 | 'w': w, 'x': x, 'y': y, 'z': z, |
56 | | - 'B': getIndexTensorAlongDim(waveform, 0), |
57 | | - 'C': getIndexTensorAlongDim(waveform, 1), |
58 | | - 'S': getIndexTensorAlongDim(waveform, 2), |
59 | | - 'R': torch.full_like(waveform, sample_rate, dtype=torch.float32), |
60 | | - 'T': torch.full_like(waveform, waveform.shape[2], dtype=torch.float32), |
61 | | - 'N': waveform.shape[1], |
62 | | - 'batch': getIndexTensorAlongDim(waveform, 0), |
63 | | - 'channel': getIndexTensorAlongDim(waveform, 1), |
64 | | - 'sample': getIndexTensorAlongDim(waveform, 2), |
65 | | - 'sample_rate': torch.full_like(waveform, sample_rate, dtype=torch.float32), |
66 | | - 'sample_count': torch.full_like(waveform, waveform.shape[2], dtype=torch.float32), |
67 | | - 'channel_count': waveform.shape[1], |
68 | | - } |
| 56 | + 'B': getIndexTensorAlongDim(av, 0), |
| 57 | + 'C': getIndexTensorAlongDim(av, 1), |
| 58 | + 'S': getIndexTensorAlongDim(av, 2), |
| 59 | + 'R': torch.full_like(av, sample_rate, dtype=torch.float32), |
| 60 | + 'T': torch.full_like(av, av.shape[2], dtype=torch.float32), |
| 61 | + 'N': av.shape[1], |
| 62 | + 'batch': getIndexTensorAlongDim(av, 0), |
| 63 | + 'channel': getIndexTensorAlongDim(av, 1), |
| 64 | + 'sample': getIndexTensorAlongDim(av, 2), |
| 65 | + 'sample_rate': torch.full_like(av, sample_rate, dtype=torch.float32), |
| 66 | + 'sample_count': torch.full_like(av, av.shape[2], dtype=torch.float32), |
| 67 | + 'channel_count': av.shape[1], |
| 68 | + } | generate_dim_variables(av) |
| 69 | + |
69 | 70 |
|
70 | | - result_tensor = eval_tensor_expr(AudioExpr, variables, waveform.shape) |
| 71 | + result_tensor = eval_tensor_expr(AudioExpr, variables, av.shape) |
71 | 72 |
|
72 | 73 | return ({'waveform': result_tensor, 'sample_rate': sample_rate},) |
0 commit comments