65 changes: 48 additions & 17 deletions src/deepgraphpose/models/fitdgp.py
@@ -476,6 +476,13 @@ def fit_dgp_labeledonly(
all_data_batch_ids = []
all_joint_locs = []
video_names = []
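# editor's note: these per-view lists are filled inside the loop below and
# concatenated into single arrays once all datasets have been visited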
all_visible_frames = []
all_hidden_frames = []
all_wt_batch_mask = []
all_visible_marker = []
all_hidden_marker = []
all_visible_marker_in_targets = []

for dataset_id in range(len(data_batcher.datasets)):
(visible_frame, hidden_frame, _, all_data_batch, joint_loc, wt_batch_mask,
all_marker_batch, addn_batch_info), d = \
@@ -485,33 +492,46 @@ def fit_dgp_labeledonly(
all_joint_locs.append(joint_loc)
# add the corresponding view name to video_names (important that these are appended to their respective lists at the same time, to preserve ordering)
video_names.append(data_batcher.datasets[dataset_id].video_name)
all_visible_frames.append(visible_frame)
all_hidden_frames.append(hidden_frame)
all_wt_batch_mask.append(wt_batch_mask)

visible_marker, hidden_marker, visible_marker_in_targets = addn_batch_info
all_visible_marker.append(visible_marker) #!
all_hidden_marker.append(hidden_marker) #!
all_visible_marker_in_targets.append(visible_marker_in_targets) #!

# collapse the per-dataset lists into single ndarrays
all_data_batch_ids = np.concatenate(all_data_batch_ids)
all_joint_locs = np.concatenate(all_joint_locs)

nt_batch = len(visible_frame) + len(hidden_frame)
visible_marker, hidden_marker, visible_marker_in_targets = addn_batch_info
all_frame = np.sort(list(visible_frame) + list(hidden_frame))
visible_frame_within_batch = [np.where(all_frame == i)[0][0] for i in visible_frame]
all_hidden_frames = np.concatenate(all_hidden_frames)
all_visible_frames = np.concatenate(all_visible_frames)
all_wt_batch_mask = np.concatenate(all_wt_batch_mask)
all_visible_marker = np.concatenate(all_visible_marker)
all_hidden_marker = np.concatenate(all_hidden_marker)
all_visible_marker_in_targets = np.concatenate(all_visible_marker_in_targets)

nt_batch = len(all_visible_frames) + len(all_hidden_frames)
#visible_marker, hidden_marker, visible_marker_in_targets = addn_batch_info
all_frame = np.sort(list(all_visible_frames) + list(all_hidden_frames))
visible_frame_within_batch = [np.where(all_frame == i)[0][0] for i in all_visible_frames]
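# editor's sketch (not in the original PR): the list comprehension above is an
# O(n^2) first-match lookup; since all_frame is sorted and contains every entry
# of all_visible_frames, np.searchsorted returns the same indices in O(n log n):
#   visible_frame_within_batch = np.searchsorted(all_frame, all_visible_frames)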

# batch data for placeholders
if dgp_cfg.wt > 0:
vector_field = learn_wt(all_data_batch) # vector field from optical flow
vector_field = learn_wt(all_data_batch_ids) # vector field from optical flow
else:
vector_field = np.zeros((1,1,1))
wt_batch = np.ones(nt_batch - 1, ) * dgp_cfg.wt
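# editor's note (assumption): wt_batch appears to hold one temporal weight per
# consecutive-frame pair, hence the nt_batch - 1 entries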

# data augmentation for visible frames
if dgp_cfg.aug and dgp_cfg.wt == 0:
all_data_batch, joint_loc = data_aug(all_data_batch, visible_frame_within_batch, joint_loc, pipeline, dgp_cfg)
all_data_batch_ids, all_joint_locs = data_aug(all_data_batch_ids, visible_frame_within_batch, all_joint_locs, pipeline, dgp_cfg)

locref_targets_batch, locref_mask_batch = coord2map(pdata, joint_loc, nx_out, ny_out, nj)
locref_targets_batch, locref_mask_batch = coord2map(pdata, all_joint_locs, nx_out, ny_out, nj)
if locref_mask_batch.shape[0] != 0:
locref_targets_all_batch = np.zeros(
(len(all_frame), nx_out, ny_out, nj * 2))
locref_targets_all_batch[
visible_frame_within_batch, :, :, :] = locref_targets_batch
locref_targets_all_batch[visible_frame_within_batch, :, :, :] = locref_targets_batch
locref_mask_all_batch = np.zeros(
(len(all_frame), nx_out, ny_out, nj * 2))
locref_mask_all_batch[visible_frame_within_batch, :, :, :] = locref_mask_batch
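# editor's note: both locref maps are zero-initialized over all frames (visible
# and hidden) and filled only at the visible-frame indices, so hidden frames are
# masked out of the locref term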
@@ -531,10 +551,10 @@ def fit_dgp_labeledonly(
placeholders['targets']: all_joint_locs,
placeholders['locref_map']: locref_targets_all_batch,
placeholders['locref_mask']: locref_mask_all_batch,
placeholders['visible_marker_pl']: visible_marker,
placeholders['hidden_marker_pl']: hidden_marker,
placeholders['visible_marker_in_targets_pl']: visible_marker_in_targets,
placeholders['wt_batch_mask_pl']: wt_batch_mask,
placeholders['visible_marker_pl']: all_visible_marker,
placeholders['hidden_marker_pl']: all_hidden_marker,
placeholders['visible_marker_in_targets_pl']: all_visible_marker_in_targets,
placeholders['wt_batch_mask_pl']: all_wt_batch_mask,
placeholders['vector_field_tf']: vector_field,
placeholders['nt_batch_pl']: nt_batch,
placeholders['wt_batch_pl']: wt_batch,
@@ -1118,21 +1138,29 @@ def dgp_loss(data_batcher, dgp_cfg, placeholders):
if data_batcher.multiview:
F_dict = data_batcher.fundamental_mat_dict
num_pts_per_frame = targets_pred.shape[1]
num_pts_per_view = tf.dtypes.cast(num_pts_per_frame * nt_batch_pl, tf.int64) # need to cast this as an int64 for some reason or it breaks
num_pts_per_view = tf.dtypes.cast(num_pts_per_frame * nt_batch_pl / len(data_batcher.datasets), tf.int64) # cast to int64: true division yields a float tensor, and the slice offsets below must be integers
#num_pts_per_view = 2
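# editor's sketch (assumption: targets_pred_marker stacks predictions view by
# view): with V = len(data_batcher.datasets) views and nt_batch_pl frames pooled
# across all views, e.g. V=2, nt_batch_pl=8, num_pts_per_frame=3 gives
# num_pts_per_view = 3 * 8 / 2 = 12, and view k owns the slice
#   targets_pred_marker[k * num_pts_per_view:(k + 1) * num_pts_per_view]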
loss['epipolar_loss'] = 0
for key, F in F_dict.items():
v1_name, v2_name = key.split(data_batcher.F_dict_key_delim)
# get coordinates of predictions for video 1
name1_idx = tf.where(tf.equal(video_names, v1_name))[0][0]
v1_pts = targets_pred_marker[name1_idx * num_pts_per_view:name1_idx * num_pts_per_view + num_pts_per_view]
v1_pts = targets_pred_marker[(name1_idx * num_pts_per_view):(name1_idx * num_pts_per_view + num_pts_per_view)]
# get coordinates of predictions for video 2
name2_idx = tf.where(tf.equal(video_names, v2_name))[0][0]
v2_pts = targets_pred_marker[name2_idx * num_pts_per_view:name2_idx * num_pts_per_view + num_pts_per_view]
v2_pts = targets_pred_marker[(name2_idx * num_pts_per_view):(name2_idx * num_pts_per_view + num_pts_per_view)]
# compute epipolar loss (v1_pts and v2_pts are index-aligned: v1_pts[n] and
# v2_pts[n] are projections of the same point in space)
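# editor's note: for a fundamental matrix F relating views 1 and 2, corresponding
# homogeneous points satisfy x2^T F x1 = 0, so |x2^T F x1| is the residual being
# penalized here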
print('********************************************')
print('num_pts_per_view: ', num_pts_per_view)
print('num_pts_per_frame: ', num_pts_per_frame)
print('v1_pts: ', v1_pts)
print('v2_pts: ', v2_pts)
print('********************************************')
epipolar_loss = compute_epipolar_loss(v1_pts, v2_pts, F)
loss['epipolar_loss'] += dgp_cfg.epipolar_wt * epipolar_loss


total_loss += loss['epipolar_loss']


@@ -1270,9 +1298,12 @@ def compute_epipolar_loss(v1_pts, v2_pts, F):
-------
scalar loss value ||x'Fx||, the magnitude of the epipolar residual v2_pts•F•v1_pts for each index-aligned point pair
"""


# convert to homogeneous coordinates
ones = tf.ones_like(v1_pts)[:,0]
ones = tf.expand_dims(ones, axis=1)

im1_pts_hom = tf.concat([v1_pts, ones], axis=1)
im2_pts_hom = tf.concat([v2_pts, ones], axis=1)
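# editor's sketch (assumption; the rest of this function is collapsed in the
# diff): with im1_pts_hom and im2_pts_hom of shape (N, 3), the batched residuals
# x2^T F x1 and a scalar loss could be computed as
#   residuals = tf.reduce_sum(im2_pts_hom * tf.matmul(im1_pts_hom, F, transpose_b=True), axis=1)
#   loss = tf.reduce_mean(tf.abs(residuals))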

Expand Down