Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 6151176

Browse files
authored
Add class conditioning (#140)
* Add class conditioning and tests Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]> * Add missing test (#140) Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]> Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
1 parent 3f72d37 commit 6151176

File tree

2 files changed

+60
-5
lines changed

2 files changed

+60
-5
lines changed

generative/networks/nets/diffusion_model_unet.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1439,6 +1439,8 @@ class DiffusionModelUNet(nn.Module):
14391439
with_conditioning: if True add spatial transformers to perform conditioning.
14401440
transformer_num_layers: number of layers of Transformer blocks to use.
14411441
cross_attention_dim: number of context dimensions to use.
1442+
num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
1443+
classes.
14421444
"""
14431445

14441446
def __init__(
@@ -1455,6 +1457,7 @@ def __init__(
14551457
with_conditioning: bool = False,
14561458
transformer_num_layers: int = 1,
14571459
cross_attention_dim: Optional[int] = None,
1460+
num_class_embeds: Optional[int] = None,
14581461
) -> None:
14591462
super().__init__()
14601463
if with_conditioning is True and cross_attention_dim is None:
@@ -1499,6 +1502,11 @@ def __init__(
14991502
nn.Linear(time_embed_dim, time_embed_dim),
15001503
)
15011504

1505+
# class embedding
1506+
self.num_class_embeds = num_class_embeds
1507+
if num_class_embeds is not None:
1508+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
1509+
15021510
# down
15031511
self.down_blocks = nn.ModuleList([])
15041512
output_channel = num_channels[0]
@@ -1591,37 +1599,46 @@ def forward(
15911599
x: torch.Tensor,
15921600
timesteps: torch.Tensor,
15931601
context: Optional[torch.Tensor] = None,
1602+
class_labels: Optional[torch.Tensor] = None,
15941603
) -> torch.Tensor:
15951604
"""
15961605
Args:
15971606
x: input tensor (N, C, SpatialDims).
15981607
timesteps: timestep tensor (N,).
15991608
context: context tensor (N, 1, ContextDim).
1609+
class_labels: class label tensor (N,).
16001610
"""
16011611
# 1. time
16021612
t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])
16031613
emb = self.time_embed(t_emb)
16041614

1605-
# 2. initial convolution
1615+
# 2. class
1616+
if self.num_class_embeds is not None:
1617+
if class_labels is None:
1618+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
1619+
class_emb = self.class_embedding(class_labels)
1620+
emb = emb + class_emb
1621+
1622+
# 3. initial convolution
16061623
h = self.conv_in(x)
16071624

1608-
# 3. down
1625+
# 4. down
16091626
down_block_res_samples: List[torch.Tensor] = [h]
16101627
for downsample_block in self.down_blocks:
16111628
h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)
16121629
for residual in res_samples:
16131630
down_block_res_samples.append(residual)
16141631

1615-
# 4. mid
1632+
# 5. mid
16161633
h = self.middle_block(hidden_states=h, temb=emb, context=context)
16171634

1618-
# 5. up
1635+
# 6. up
16191636
for upsample_block in self.up_blocks:
16201637
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
16211638
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
16221639
h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context)
16231640

1624-
# 6. output block
1641+
# 7. output block
16251642
h = self.out(h)
16261643

16271644
return h

tests/test_diffusion_model_unet.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,44 @@ def test_with_conditioning_cross_attention_dim_none(self):
168168
norm_num_groups=8,
169169
)
170170

171+
def test_shape_conditioned_models_class_conditioning(self):
    """Check that a class-conditional 2D model returns an output shaped like its input."""
    # num_class_embeds=2 makes the model class-conditional, so the forward
    # pass below must be given class_labels.
    model_kwargs = dict(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        num_res_blocks=1,
        num_channels=(8, 8, 8),
        attention_levels=(False, False, True),
        norm_num_groups=8,
        num_head_channels=8,
        num_class_embeds=2,
    )
    net = DiffusionModelUNet(**model_kwargs)

    expected_shape = (1, 1, 16, 32)
    with eval_mode(net):
        # Draw the random inputs in the order: image, timesteps, labels.
        image = torch.rand((1, 1, 16, 32))
        steps = torch.randint(0, 1000, (1,)).long()
        labels = torch.randint(0, 2, (1,)).long()
        result = net.forward(x=image, timesteps=steps, class_labels=labels)
        self.assertEqual(result.shape, expected_shape)
190+
191+
def test_conditioned_models_no_class_labels(self):
    """Check that forward raises ValueError when a class-conditional model gets no class_labels."""
    # Build the model outside assertRaises: a ValueError raised by the
    # constructor must fail this test instead of satisfying it by accident.
    net = DiffusionModelUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        num_res_blocks=1,
        num_channels=(8, 8, 8),
        attention_levels=(False, False, True),
        norm_num_groups=8,
        num_head_channels=8,
        num_class_embeds=2,
    )
    # Only the forward call, which omits class_labels, is expected to raise.
    with self.assertRaises(ValueError):
        net.forward(
            x=torch.rand((1, 1, 16, 32)),
            timesteps=torch.randint(0, 1000, (1,)).long(),
        )
208+
171209
def test_script_unconditioned_2d_models(self):
172210
net = DiffusionModelUNet(
173211
spatial_dims=2,

0 commit comments

Comments
 (0)