File tree Expand file tree Collapse file tree 1 file changed +13
-1
lines changed Expand file tree Collapse file tree 1 file changed +13
-1
lines changed Original file line number Diff line number Diff line change @@ -3283,7 +3283,7 @@ def get_least_asym_entity_or_longest_length(
32833283
32843284 # Calculate entity length
32853285 entity_mask = batch ["entity_id" ] == entity_id
3286- entity_length [int (entity_id )] = entity_mask .sum () .item ()
3286+ entity_length [int (entity_id )] = entity_mask .sum (- 1 ). mode (). values .item ()
32873287
32883288 min_asym_count = min (entity_asym_count .values ())
32893289 least_asym_entities = [
@@ -3313,6 +3313,18 @@ def get_least_asym_entity_or_longest_length(
33133313 if asym_id in input_asym_id
33143314 ]
33153315
3316+ # Since the entity ID to asym ID mapping is many-to-many, we need to select only
3317+ # prediction asym IDs with equal length w.r.t. the sampled ground truth asym ID
3318+ anchor_gt_asym_id_length = (
3319+ (batch ["asym_id" ] == anchor_gt_asym_id ).sum (- 1 ).mode ().values .item ()
3320+ )
3321+ anchor_pred_asym_ids = [
3322+ asym_id
3323+ for asym_id in anchor_pred_asym_ids
3324+ if (batch ["asym_id" ] == asym_id ).sum (- 1 ).mode ().values .item ()
3325+ == anchor_gt_asym_id_length
3326+ ]
3327+
33163328 # Remap `asym_id` values to remove any gaps in the ground truth asym IDs,
33173329 # but leave the prediction asym IDs as is since they are used for masking
33183330 sorted_asym_ids = sorted (all_asym_ids )
You can’t perform that action at this time.
0 commit comments