mgd loss #3

@arielsolomon

Description

Search before asking

  • I have searched the Ultralytics YOLO issues and found no similar bug report.

Ultralytics YOLO Component

Train

Bug

Hi, first of all, thank you for your much-needed work.
Response-based distillation ('cwd' loss) trained smoothly, but when I switched to the feature-based loss ('mgd') I got an error from a BatchNorm layer about the expected feature size:
Traceback (most recent call last):
  File "/work/yv8_distillation/run_yv8_dist.py", line 7, in <module>
    student_model.train(
  File "/work/yv8_distillation/ultralytics/engine/model.py", line 802, in train
    self.trainer.train()
  File "/work/yv8_distillation/ultralytics/engine/trainer.py", line 499, in train
    self._do_train(world_size)
  File "/work/yv8_distillation/ultralytics/engine/trainer.py", line 712, in _do_train
    self.d_loss = distillation_loss.get_loss()
  File "/work/yv8_distillation/ultralytics/engine/trainer.py", line 323, in get_loss
    quant_loss = self.distill_loss_fn(y_s=self.student_outputs, y_t=self.teacher_outputs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/work/yv8_distillation/ultralytics/engine/trainer.py", line 211, in forward
    t = self.norm1[idx](t)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py", line 175, in forward
    return F.batch_norm(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/functional.py", line 2509, in batch_norm
    return torch.batch_norm(
RuntimeError: running_mean should contain 512 elements not 256
I am not sure where to change it. I would appreciate your help.
Best,
Ariel
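The mismatch can be reproduced in isolation, outside the Ultralytics code: a BatchNorm2d sized for one channel count rejects a feature map with another, with the same RuntimeError as in the traceback. (A minimal sketch with made-up channel counts, for illustration only.)

```python
import torch
import torch.nn as nn

# BatchNorm sized for 256 channels, as the norm layer in the distillation
# code apparently is...
bn = nn.BatchNorm2d(256)

# ...but the feature map being normalized has 512 channels.
x = torch.randn(2, 512, 8, 8)

try:
    bn(x)
except RuntimeError as e:
    print(e)  # e.g. "running_mean should contain 512 elements not 256"
```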

Environment

New https://pypi.org/project/ultralytics/8.3.28 available πŸ˜ƒ Update with 'pip install -U ultralytics'
Ultralytics 8.3.24 πŸš€ Python-3.10.13 torch-2.3.0+cu121 CUDA:0 (NVIDIA GeForce RTX 2080 Ti, 11012MiB)

Minimal Reproducible Example

import torch
import torch.nn as nn


class MGDLoss(nn.Module):
    def __init__(self,
                 student_channels,
                 teacher_channels,
                 alpha_mgd=0.00002,
                 lambda_mgd=0.65,
                 ):
        super().__init__()
        self.alpha_mgd = alpha_mgd
        self.lambda_mgd = lambda_mgd
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.generation = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(channel, channel, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel, channel, kernel_size=3, padding=1)
            ).to(device) for channel in teacher_channels
        ])

    def forward(self, y_s, y_t, layer=None):
        """Forward computation.

        Args:
            y_s (list): The student model predictions with shape (N, C, H, W) in a list.
            y_t (list): The teacher model predictions with shape (N, C, H, W) in a list.
        Return:
            torch.Tensor: The calculated loss value of all stages.
        """
        losses = []
        for idx, (s, t) in enumerate(zip(y_s, y_t)):
            # assert s.shape == t.shape
            if layer == "outlayer":
                idx = -1
            losses.append(self.get_dis_loss(s, t, idx) * self.alpha_mgd)
        return sum(losses)

    def get_dis_loss(self, preds_S, preds_T, idx):
        loss_mse = nn.MSELoss(reduction='sum')
        N, C, H, W = preds_T.shape

        device = preds_S.device
        # Randomly zero out a lambda_mgd fraction of spatial positions
        # in the student feature before regeneration.
        mat = torch.rand((N, 1, H, W)).to(device)
        mat = torch.where(mat > 1 - self.lambda_mgd, 0, 1).to(device)

        masked_fea = torch.mul(preds_S, mat)
        new_fea = self.generation[idx](masked_fea)

        dis_loss = loss_mse(new_fea, preds_T) / N
        return dis_loss
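For what it's worth, the traceback suggests a student feature map (256 channels) is being fed into a layer sized for the teacher's channel count (512). One common workaround, sketched below under that assumption, is a per-layer 1x1 projection that maps each student feature to the teacher's channel count before the MGD mask-and-regenerate step. The `align` name and channel counts here are illustrative, not part of the Ultralytics code.

```python
import torch
import torch.nn as nn

# Hypothetical channel counts for three distillation points where the
# student is narrower than the teacher.
student_channels = [128, 256, 512]
teacher_channels = [256, 512, 1024]

# One 1x1 conv per layer, projecting student channels -> teacher channels.
align = nn.ModuleList([
    nn.Conv2d(s_ch, t_ch, kernel_size=1)
    for s_ch, t_ch in zip(student_channels, teacher_channels)
])

# Student features at each distillation point (dummy data).
s_feats = [torch.randn(2, c, 16, 16) for c in student_channels]

# After alignment, each feature matches the teacher's channel count,
# so norm/generation layers built from teacher_channels accept it.
aligned = [align[i](f) for i, f in enumerate(s_feats)]
for a, t_ch in zip(aligned, teacher_channels):
    assert a.shape[1] == t_ch
```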

Additional

No response

Are you willing to submit a PR?

  • Yes I'd like to help by submitting a PR!
