
Fix In-place Assignments for PiecewiseRationalQuadratic Compatibility with functorch and torch2.0 #77

Open
HamidrezaKmK wants to merge 1 commit into bayesiains:master from HamidrezaKmK:master

Conversation

@HamidrezaKmK

Hello,

While working on my project (OOD Detection using Manifolds), I noticed an issue with the current implementation of the PiecewiseRationalQuadratic coupling layers. The in-place masked assignments, specifically:

outputs[outside_interval_mask] = inputs[outside_interval_mask]

pose challenges when constructing the computation graph in both functorch and torch 2.0.

To address this, I've made the necessary modifications in my fork, providing a functional adaptation that works with these libraries without altering the layer's primary functionality.
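
Concretely, the change replaces the in-place pattern with torch.where (a minimal sketch; the tensor names match the diff snippets later in this thread):

# In-place masked assignment: mutates `outputs` and breaks graph tracing.
outputs[outside_interval_mask] = inputs[outside_interval_mask]

# Functional equivalent used in this PR: builds a new tensor instead.
outputs = torch.where(outside_interval_mask, inputs, outputs)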

I kindly ask for your review of these changes. If they align with your vision and maintain the library's integrity, I'd appreciate their incorporation into the main branch.

Thank you for your time and consideration.

Best,
Hamid

@arturbekasov (Contributor) left a comment:

Hey Hamid,

Thanks for taking the time to submit the PR. I've left a few comments.

I am not against making the code more functional, but I would like to be careful about the performance implications of the additional copies. Is there any chance you could share some before/after performance stats for the function?

Cheers,

Artur
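
For reference, one way to gather such before/after timings is torch.utils.benchmark (a minimal sketch; the two variants and the input shape below are illustrative stand-ins, not the PR's actual spline code):

import torch
import torch.utils.benchmark as benchmark

# Stand-ins for the two patterns under discussion.
def inplace_variant(inputs, mask):
    outputs = torch.zeros_like(inputs)
    outputs[mask] = inputs[mask]  # in-place masked assignment
    return outputs

def functional_variant(inputs, mask):
    # out-of-place: allocates a new tensor via torch.where
    return torch.where(mask, inputs, torch.zeros_like(inputs))

inputs = torch.randn(4096, 10)
mask = inputs.abs() > 1.0

for fn in (inplace_variant, functional_variant):
    timer = benchmark.Timer(
        stmt="fn(inputs, mask)",
        globals={"fn": fn, "inputs": inputs, "mask": mask},
    )
    print(fn.__name__, timer.timeit(1000))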

min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
enable_identity_init=False,
@arturbekasov (Contributor):

Is this on purpose, or should we rebase the changes?

outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0
outputs = torch.where(outside_interval_mask, inputs, outputs)
@arturbekasov (Contributor):

Should we drop the zero init above if we're copying here anyway? I.e. outputs = torch.where(outside_interval_mask, inputs, torch.zeros_like(inputs)).

logabsdet[outside_interval_mask] = 0
outputs = torch.where(outside_interval_mask, inputs, outputs)
logabsdet = torch.where(outside_interval_mask, torch.zeros_like(logabsdet), logabsdet)
@arturbekasov (Contributor):

Actually, does the original line even have an effect? We're assigning zeros to what is already initialized to zeros.

(
outputs[inside_interval_mask],
logabsdet[inside_interval_mask],
# outputs[inside_interval_mask],
@arturbekasov (Contributor):

Let's not leave unused code in comments.

cumwidths[..., -1] = right
widths = cumwidths[..., 1:] - cumwidths[..., :-1]

if enable_identity_init: #flow is the identity if initialized with parameters equal to zero
@arturbekasov (Contributor):

Same as my earlier comment: we want to keep this; not sure if this change is on purpose.

)


# turn inside_interval_mask into an int tensor
@arturbekasov (Contributor):

Could we simplify the code below by using masked_scatter?
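
For reference, the out-of-place masked_scatter fills the masked positions of a tensor from a flat source, in order (a minimal sketch with illustrative names, not the PR's actual code):

import torch

inputs = torch.randn(8)
inside_interval_mask = inputs.abs() < 1.0

# Values computed only for the masked elements (a flat tensor).
transformed = inputs[inside_interval_mask] * 2.0

# Returns a new tensor: True positions are filled from `transformed`
# in order; the remaining positions keep their existing values.
outputs = torch.zeros_like(inputs).masked_scatter(inside_interval_mask, transformed)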
