diff --git a/Python/src/module.cpp b/Python/src/module.cpp index 6354ec95..3c52e6a9 100644 --- a/Python/src/module.cpp +++ b/Python/src/module.cpp @@ -554,7 +554,9 @@ PYBIND11_MODULE(pydirectml, module) py::arg("new_size"), py::arg("new_strides")); - module.def("activation_soft_max", &dml::ActivationSoftmax, "Raise all elements to e, and divide all the elements in each batch by that batch's sum.", + module.def("activation_soft_max", [](dml::Expression input) { + return dml::ActivationSoftmax(input); + }, "Raise all elements to e, and divide all the elements in each batch by that batch's sum.", py::arg("input")); module.def("join", [](