
input sizes mismatch for nn.CrossEntropyLoss() #115

@ghost

Description

Hi @natanielruiz, thanks for sharing your work.

While trying to adapt it for a project, I stumbled upon a problem.
Lines 160 to 161 of your train_hopenet.py file read:

# Forward pass
yaw, pitch, roll = model(images)

If I am not mistaken, each angle predicted by the model has size (batch_size, num_bins), for example (128, 66).
That makes sense, because the fully connected layer has an output size of 66.

While investigating the data handling in datasets.py, I found the following code block:

# We get the pose in radians
pose = utils.get_ypr_from_mat(mat_path)
# And convert to degrees.
pitch = pose[0] * 180 / np.pi
yaw = pose[1] * 180 / np.pi
roll = pose[2] * 180 / np.pi
# Bin values
bins = np.array(range(-99, 102, 3))
labels = torch.LongTensor(np.digitize([yaw, pitch, roll], bins) - 1)

Assuming that the head pose has three values, one for each angle,
the labels variable then holds the bin index of each angle, for example [30, 33, 33].
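
The binning step can be checked in isolation. The following is a minimal sketch using only NumPy; the three angles in degrees are made-up values chosen for illustration, not taken from the dataset:

```python
import numpy as np

# Same bin edges as in datasets.py: -99, -96, ..., 99 (67 edges, 3 degrees apart).
bins = np.array(range(-99, 102, 3))

# Hypothetical yaw, pitch, roll in degrees, chosen only for illustration.
yaw, pitch, roll = -8.0, 1.5, 0.0

# np.digitize returns, for each value, the index i with bins[i-1] <= x < bins[i];
# subtracting 1 turns that into a zero-based bin index.
binned = np.digitize([yaw, pitch, roll], bins) - 1

print(len(bins))  # 67
print(binned)     # [30 33 33]
```

Each sample therefore yields one integer bin index per angle. For angles in the range [-99, 99) the index lies between 0 and 65, which lines up with the 66 outputs of the fully connected layer.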

The first code block is followed by:

label_yaw = Variable(labels[:,0]).cuda(gpu)
label_pitch = Variable(labels[:,1]).cuda(gpu)
label_roll = Variable(labels[:,2]).cuda(gpu)

# Continuous labels
label_yaw_cont = Variable(cont_labels[:,0]).cuda(gpu)
label_pitch_cont = Variable(cont_labels[:,1]).cuda(gpu)
label_roll_cont = Variable(cont_labels[:,2]).cuda(gpu)

# Cross entropy loss
loss_yaw = criterion(yaw, label_yaw)
loss_pitch = criterion(pitch, label_pitch)
loss_roll = criterion(roll, label_roll)

with the criterion being nn.CrossEntropyLoss().cuda(gpu).

This is where I get confused, because the sizes of the inputs do not seem to match: yaw has size (128, 66), but label_yaw has size (128, 1).
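
To make the shapes concrete, here is a small sketch that uses random stand-in tensors instead of the actual Hopenet model and data loader. The names `logits_yaw` and `labels` are placeholders, the batch size of 128 is taken from the example above, and everything runs on the CPU:

```python
import torch
import torch.nn as nn

batch_size, num_bins = 128, 66

# Stand-in for one model output head: raw class scores, shape (N, C).
logits_yaw = torch.randn(batch_size, num_bins)

# Stand-in for a batch of binned labels: one bin index per angle, shape (N, 3).
labels = torch.randint(0, num_bins, (batch_size, 3))

# Selecting a single column drops that dimension: shape (N,), not (N, 1).
label_yaw = labels[:, 0]

criterion = nn.CrossEntropyLoss()
loss_yaw = criterion(logits_yaw, label_yaw)

print(logits_yaw.shape)  # torch.Size([128, 66])
print(label_yaw.shape)   # torch.Size([128])
print(loss_yaw.shape)    # torch.Size([]), a scalar loss
```

In this sketch the target holds one class index per sample, so it has no class dimension that would need to match the 66 scores per sample in the prediction.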

Could you please tell me where I am going wrong?
Any help is appreciated.

Kind regards
