Add ELU_FWD pointwise op support #311

Open
rsuderman wants to merge 1 commit into iree-org:main from rsuderman:pointwise_elu_fwd

Conversation

@rsuderman (Contributor)

No description provided.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Rob Suderman <rob.suderman@gmail.com>
@rsuderman rsuderman requested a review from sjain-stanford April 3, 2026 23:12
}
case PointwiseAttr::Mode::ELU_FWD: {
double xD = static_cast<double>(x);
y = xD >= 0 ? xD : std::expm1(xD);
Member

Shouldn't matter but just to be precise with PyTorch semantics:

Suggested change:
- y = xD >= 0 ? xD : std::expm1(xD);
+ y = xD > 0 ? xD : std::expm1(xD);

https://docs.pytorch.org/docs/stable/generated/torch.nn.ELU.html#torch.nn.ELU

Comment on lines +1726 to +1733
constexpr std::string_view kEluSchema = R"(
{0}
%elu_alpha_{7} = torch.constant.float 1.000000e+00
%elu_scale_{7} = torch.constant.float 1.000000e+00
%elu_input_scale_{7} = torch.constant.float 1.000000e+00
{1} = {6} {2}, %elu_alpha_{7}, %elu_scale_{7}, %elu_input_scale_{7} : {3}, !torch.float, !torch.float, !torch.float -> {4}
{5}
)";
Member

We need alpha to be configurable and not hardcoded (OK to have a default when not set).

