Softtorch is not yet fully released! We are currently finalizing the library and plan to release it officially (alongside a similar "Softjax" library) by the end of the year. If you have already stumbled upon this library, feel free to use and test the code on GitHub, and please reach out if you encounter any issues or have suggestions for improvement. Thanks!
Note also that parts of the API and internals may still change substantially before the official release. The package will also only be installable via pip after the official release.
Softtorch provides soft differentiable drop-in replacements for traditionally non-differentiable functions in PyTorch, including
- elementwise operators: abs, relu, clamp, sign, round and heaviside;
- tensor-valued operators: (arg)max, (arg)min, (arg)median, (arg)sort, (arg)topk and ranking;
- comparison operators such as: greater, equal or isclose;
- logical operators such as: logical_and, all or where;
- functions for selection with indices such as: where, take_along_dim, index_select or choose.
All operators offer multiple modes and an adjustable softening strength, so that properties such as the smoothness of the soft function or the boundedness of the softened region can be chosen according to the user's needs.
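To illustrate the trade-off between smoothness and boundedness, here is a minimal sketch in plain PyTorch comparing two ways to soften the Heaviside step: a logistic sigmoid, which is smooth but differs from the hard step everywhere, and a clamped linear ramp, which is only piecewise smooth but equals the hard step exactly outside [-tau, tau]. The function names and the temperature parameter tau are hypothetical and are not Softtorch's API or its actual modes.

```python
import torch


def heaviside_sigmoid(x: torch.Tensor, tau: float = 0.1) -> torch.Tensor:
    # Smooth everywhere, but never exactly 0 or 1: the softened region is unbounded.
    return torch.sigmoid(x / tau)


def heaviside_ramp(x: torch.Tensor, tau: float = 0.1) -> torch.Tensor:
    # Linear ramp from 0 at x = -tau to 1 at x = +tau; identical to the hard step outside.
    return torch.clamp(x / (2 * tau) + 0.5, 0.0, 1.0)


x = torch.tensor([-0.2, -0.05, 0.0, 0.05, 0.2])
for tau in (0.1, 0.01):
    print(f"tau={tau}")
    print("  sigmoid:", heaviside_sigmoid(x, tau))
    print("  ramp:   ", heaviside_ramp(x, tau))
```

Lowering tau sharpens both approximations toward the hard step. With the ramp, gradients vanish exactly outside [-tau, tau], while the sigmoid keeps small but non-zero gradients everywhere.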
Moreover, Softtorch tightly integrates straight-through estimation: the non-differentiable function is used in the forward pass and its differentiable replacement in the backward pass.
The Softtorch library is designed to require minimal user effort: simply replace the non-differentiable PyTorch function with its Softtorch counterpart. However, special care is needed with functions operating on indices, since we relax the notion of an index into a distribution over indices, which changes the shape of the values these functions return or accept.
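The shape change caused by relaxing indices can be illustrated in plain PyTorch. In this sketch (an assumption for illustration, not Softtorch's implementation), a hard argmax is represented as a one-hot vector over the n positions and a soft argmax as a softmax distribution with a hypothetical temperature tau. Selecting a value then becomes a weighted sum instead of an integer lookup.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-0.2, -1.0, 0.3, 1.0])
tau = 0.1  # hypothetical softening temperature

# Hard index: a scalar integer, used for an ordinary lookup.
idx = torch.argmax(x)  # shape ()
hard_value = x[idx]

# The same hard index as a distribution: a one-hot vector of shape (n,) instead of ().
one_hot = F.one_hot(idx, num_classes=x.numel()).to(x.dtype)

# Soft index: a probability distribution over positions, also of shape (n,).
soft_idx = torch.softmax(x / tau, dim=-1)

# Selection becomes a weighted sum over positions, which is differentiable w.r.t. x.
print(hard_value, (one_hot * x).sum(), (soft_idx * x).sum())
```

Code that previously indexed with an integer must therefore accept an extra trailing dimension of size n and contract over it.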
Requires Python 3.10+.
pip install softtorch
Documentation is available at https://a-paulus.github.io/softtorch/.
import torch
import softtorch as st
x = torch.tensor([-0.2, -1.0, 0.3, 1.0])
# Elementwise functions
print("\nTorch absolute:", torch.abs(x))
print("SoftTorch absolute (hard mode):", st.abs(x, mode="hard"))
print("SoftTorch absolute (soft mode):", st.abs(x))
print("\nTorch clamp:", torch.clamp(x, -0.5, 0.5))
print("SoftTorch clamp (hard mode):", st.clamp(x, -0.5, 0.5, mode="hard"))
print("SoftTorch clamp (soft mode):", st.clamp(x, -0.5, 0.5))
print("\nTorch heaviside:", torch.heaviside(x, torch.tensor(0.5)))
print("SoftTorch heaviside (hard mode):", st.heaviside(x, mode="hard"))
print("SoftTorch heaviside (soft mode):", st.heaviside(x))
print("\nTorch ReLU:", torch.nn.functional.relu(x))
print("SoftTorch ReLU (hard mode):", st.relu(x, mode="hard"))
print("SoftTorch ReLU (soft mode):", st.relu(x))
print("\nTorch round:", torch.round(x))
print("SoftTorch round (hard mode):", st.round(x, mode="hard"))
print("SoftTorch round (soft mode):", st.round(x))
print("\nTorch sign:", torch.sign(x))
print("SoftTorch sign (hard mode):", st.sign(x, mode="hard"))
print("SoftTorch sign (soft mode):", st.sign(x))
Torch absolute: tensor([0.2000, 1.0000, 0.3000, 1.0000])
SoftTorch absolute (hard mode): tensor([0.2000, 1.0000, 0.3000, 1.0000])
SoftTorch absolute (soft mode): tensor([0.1523, 0.9999, 0.2715, 0.9999])
Torch clamp: tensor([-0.2000, -0.5000, 0.3000, 0.5000])
SoftTorch clamp (hard mode): tensor([-0.2000, -0.5000, 0.3000, 0.5000])
SoftTorch clamp (soft mode): tensor([-0.1952, -0.4993, 0.2873, 0.4993])
Torch heaviside: tensor([0., 0., 1., 1.])
SoftTorch heaviside (hard mode): tensor([0., 0., 1., 1.])
SoftTorch heaviside (soft mode): tensor([1.1920e-01, 4.5398e-05, 9.5257e-01, 9.9995e-01])
Torch ReLU: tensor([0.0000, 0.0000, 0.3000, 1.0000])
SoftTorch ReLU (hard mode): tensor([0.0000, 0.0000, 0.3000, 1.0000])
SoftTorch ReLU (soft mode): tensor([1.2693e-02, 4.5399e-06, 3.0486e-01, 1.0000e+00])
Torch round: tensor([-0., -1., 0., 1.])
SoftTorch round (hard mode): tensor([-0., -1., 0., 1.])
SoftTorch round (soft mode): tensor([-0.0465, -1.0000, 0.1189, 1.0000])
Torch sign: tensor([-1., -1., 1., 1.])
SoftTorch sign (hard mode): tensor([-1., -1., 1., 1.])
SoftTorch sign (soft mode): tensor([-0.7616, -0.9999, 0.9051, 0.9999])
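The soft-mode values above are consistent with simple sigmoid-based relaxations at a temperature of 0.1. The following sketch reproduces them in plain PyTorch under that assumption. It is meant to convey the idea and is not necessarily how Softtorch implements these functions (soft round is omitted); the soft_* names and the tau parameter are hypothetical.

```python
import torch
import torch.nn.functional as F

TAU = 0.1  # assumed softening temperature


def soft_heaviside(x, tau=TAU):
    # Logistic sigmoid as a smooth step.
    return torch.sigmoid(x / tau)


def soft_sign(x, tau=TAU):
    # 2 * sigmoid(x / tau) - 1, written equivalently as tanh(x / (2 * tau)).
    return torch.tanh(x / (2 * tau))


def soft_abs(x, tau=TAU):
    # |x| = x * sign(x), with the sign softened.
    return x * soft_sign(x, tau)


def soft_relu(x, tau=TAU):
    # Softplus with sharpness 1 / tau.
    return F.softplus(x, beta=1.0 / tau)


def soft_clamp(x, lo, hi, tau=TAU):
    # clamp(x, lo, hi) = lo + relu(x - lo) - relu(x - hi), with both ReLUs softened.
    return lo + soft_relu(x - lo, tau) - soft_relu(x - hi, tau)


x = torch.tensor([-0.2, -1.0, 0.3, 1.0])
print("heaviside:", soft_heaviside(x))
print("sign:     ", soft_sign(x))
print("abs:      ", soft_abs(x))
print("relu:     ", soft_relu(x))
print("clamp:    ", soft_clamp(x, -0.5, 0.5))
```

Building abs and clamp from sign and ReLU keeps a single temperature knob shared by all elementwise operators.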
# Tensor-valued operators
print("\nTorch max:", torch.max(x))
print("SoftTorch max (hard mode):", st.max(x, mode="hard"))
print("SoftTorch max (soft mode):", st.max(x))
print("\nTorch min:", torch.min(x))
print("SoftTorch min (hard mode):", st.min(x, mode="hard"))
print("SoftTorch min (soft mode):", st.min(x))
print("\nTorch median:", torch.median(x))
print("SoftTorch median (hard mode):", st.median(x, mode="hard"))
print("SoftTorch median (soft mode):", st.median(x))
print("\nTorch sort:", torch.sort(x).values)
print("SoftTorch sort (hard mode):", st.sort(x, mode="hard").values)
print("SoftTorch sort (soft mode):", st.sort(x).values)
print("\nTorch topk:", torch.topk(x, k=2).values)
print("SoftTorch topk (hard mode):", st.topk(x, k=2, mode="hard").values)
print("SoftTorch topk (soft mode):", st.topk(x, k=2).values)
print("\nTorch ranking:", torch.argsort(torch.argsort(x)))
print("SoftTorch ranking (hard mode):", st.ranking(x, descending=False, mode="hard"))
print("SoftTorch ranking (soft mode):", st.ranking(x, descending=False))
Torch max: tensor(1.)
SoftTorch max (hard mode): tensor(1.)
SoftTorch max (soft mode): tensor(0.9994)
Torch min: tensor(-1.)
SoftTorch min (hard mode): tensor(-1.)
SoftTorch min (soft mode): tensor(-0.9997)
Torch median: tensor(-0.2000)
SoftTorch median (hard mode): tensor(0.0500)
SoftTorch median (soft mode): tensor(0.0500)
Torch sort: tensor([-1.0000, -0.2000, 0.3000, 1.0000])
SoftTorch sort (hard mode): tensor([-1.0000, -0.2000, 0.3000, 1.0000])
SoftTorch sort (soft mode): tensor([-0.9997, -0.1969, 0.2973, 0.9994])
Torch topk: tensor([1.0000, 0.3000])
SoftTorch topk (hard mode): tensor([1.0000, 0.3000])
SoftTorch topk (soft mode): tensor([0.9994, 0.2973])
Torch ranking: tensor([1, 0, 2, 3])
SoftTorch ranking (hard mode): tensor([1., 0., 2., 3.])
SoftTorch ranking (soft mode): tensor([1.0064e+00, 3.3987e-04, 1.9942e+00, 2.9991e+00])
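The soft max and min values above are consistent with softmax-weighted averages of x at a temperature of 0.1. A minimal sketch under that assumption follows (the function names are hypothetical, and this is not necessarily Softtorch's implementation). It also shows why the medians differ: for an even number of elements, torch.median returns the lower of the two middle values (-0.2 here), whereas the hard-mode Softtorch median above equals their mean (0.05), which is also what torch.quantile returns with its default linear interpolation.

```python
import torch

TAU = 0.1  # assumed softening temperature


def soft_max(x, tau=TAU, dim=-1):
    # Average of x weighted by a softmax over x / tau (a smooth stand-in for the max).
    return (torch.softmax(x / tau, dim=dim) * x).sum(dim=dim)


def soft_min(x, tau=TAU, dim=-1):
    # Same construction with the sign flipped.
    return -soft_max(-x, tau, dim)


x = torch.tensor([-0.2, -1.0, 0.3, 1.0])
print("soft max:", soft_max(x))  # close to 1.0
print("soft min:", soft_min(x))  # close to -1.0

# Median conventions for an even number of elements.
print("torch.median:  ", torch.median(x))         # lower of the two middle values
print("torch.quantile:", torch.quantile(x, 0.5))  # mean of the two middle values
```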
# Operators returning indices
print("\nTorch argmax:", torch.argmax(x))
print("SoftTorch argmax (hard mode):", st.argmax(x, mode="hard"))
print("SoftTorch argmax (soft mode):", st.argmax(x))
print("\nTorch argmin:", torch.argmin(x))
print("SoftTorch argmin (hard mode):", st.argmin(x, mode="hard"))
print("SoftTorch argmin (soft mode):", st.argmin(x))
print("\nTorch argmedian:", torch.median(x, dim=0).indices)
print("SoftTorch argmedian (hard mode):", st.median(x, mode="hard", dim=0).indices)
print("SoftTorch argmedian (soft mode):", st.median(x, dim=0).indices)
print("\nTorch argsort:", torch.argsort(x))
print("SoftTorch argsort (hard mode):", st.argsort(x, mode="hard"))
print("SoftTorch argsort (soft mode):", st.argsort(x))
print("\nTorch argtopk:", torch.topk(x, k=2).indices)
print("SoftTorch argtopk (hard mode):", st.topk(x, k=2, mode="hard").indices)
print("SoftTorch argtopk (soft mode):", st.topk(x, k=2).indices)
Torch argmax: tensor(3)
SoftTorch argmax (hard mode): tensor([0., 0., 0., 1.])
SoftTorch argmax (soft mode): tensor([6.1386e-06, 2.0593e-09, 9.1105e-04, 9.9908e-01])
Torch argmin: tensor(1)
SoftTorch argmin (hard mode): tensor([0., 1., 0., 0.])
SoftTorch argmin (soft mode): tensor([3.3535e-04, 9.9966e-01, 2.2596e-06, 2.0605e-09])
Torch argmedian: tensor(0)
SoftTorch argmedian (hard mode): tensor([0.5000, 0.0000, 0.5000, 0.0000])
SoftTorch argmedian (soft mode): tensor([5.0000e-01, 5.6268e-08, 5.0000e-01, 4.1576e-07])
Torch argsort: tensor([1, 0, 2, 3])
SoftTorch argsort (hard mode): tensor([[0., 1., 0., 0.],
[1., 0., 0., 0.],
[0., 0., 1., 0.],
[0., 0., 0., 1.]])
SoftTorch argsort (soft mode): tensor([[3.3535e-04, 9.9966e-01, 2.2596e-06, 2.0605e-09],
[9.9297e-01, 3.3310e-04, 6.6906e-03, 6.1010e-06],
[6.6868e-03, 2.2432e-06, 9.9241e-01, 9.0496e-04],
[6.1386e-06, 2.0593e-09, 9.1105e-04, 9.9908e-01]])
Torch argtopk: tensor([3, 2])
SoftTorch argtopk (hard mode): tensor([[0., 0., 0., 1.],
[0., 0., 1., 0.]])
SoftTorch argtopk (soft mode): tensor([[6.1386e-06, 2.0593e-09, 9.1105e-04, 9.9908e-01],
[6.6868e-03, 2.2432e-06, 9.9241e-01, 9.0496e-04]])
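The soft-mode argsort matrix above is consistent with the SoftSort relaxation listed among the references, using an absolute-difference cost and a temperature of 0.1: row i is a softmax over how close each input is to the i-th sorted value. The sketch below reconstructs it in plain PyTorch under that assumption. It is not necessarily how Softtorch computes it, and the function name and tau parameter are hypothetical. Its last row coincides with the soft argmax printed above.

```python
import torch


def softsort(x: torch.Tensor, tau: float = 0.1, descending: bool = False):
    # Hard-sorted values serve as anchors for each output position.
    anchors = torch.sort(x, descending=descending).values
    # Cost of assigning input j to sorted position i: |anchors[i] - x[j]|.
    cost = (anchors.unsqueeze(-1) - x.unsqueeze(-2)).abs()
    # Row-stochastic soft permutation matrix of shape (n, n).
    perm = torch.softmax(-cost / tau, dim=-1)
    # Soft sorted values are convex combinations of the inputs.
    values = perm @ x
    return values, perm


x = torch.tensor([-0.2, -1.0, 0.3, 1.0])
values, perm = softsort(x)
print(values)
print(perm)

# The first k rows of the descending version give a soft top-k.
top_values, top_perm = softsort(x, descending=True)
print(top_values[:2], top_perm[:2])
```

Because each row is a distribution over input positions, the "indices" have shape (n, n) for argsort and (k, n) for top-k, matching the shapes printed above.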
y = torch.tensor([0.2, -0.5, 0.5, -1.0])
# Comparison operators
print("\nTorch greater:", torch.greater(x, y))
print("SoftTorch greater (hard mode):", st.greater(x, y, mode="hard"))
print("SoftTorch greater (soft mode):", st.greater(x, y))
print("\nTorch greater equal:", torch.greater_equal(x, y))
print("SoftTorch greater equal (hard mode):", st.greater_equal(x, y, mode="hard"))
print("SoftTorch greater equal (soft mode):", st.greater_equal(x, y))
print("\nTorch less:", torch.less(x, y))
print("SoftTorch less (hard mode):", st.less(x, y, mode="hard"))
print("SoftTorch less (soft mode):", st.less(x, y))
print("\nTorch less equal:", torch.less_equal(x, y))
print("SoftTorch less equal (hard mode):", st.less_equal(x, y, mode="hard"))
print("SoftTorch less equal (soft mode):", st.less_equal(x, y))
print("\nTorch equal:", torch.eq(x, y))
print("SoftTorch equal (hard mode):", st.equal(x, y, mode="hard"))
print("SoftTorch equal (soft mode):", st.equal(x, y))
print("\nTorch not equal:", torch.not_equal(x, y))
print("SoftTorch not equal (hard mode):", st.not_equal(x, y, mode="hard"))
print("SoftTorch not equal (soft mode):", st.not_equal(x, y))
print("\nTorch isclose:", torch.isclose(x, y))
print("SoftTorch isclose (hard mode):", st.isclose(x, y, mode="hard"))
print("SoftTorch isclose (soft mode):", st.isclose(x, y))
Torch greater: tensor([False, False, False, True])
SoftTorch greater (hard mode): tensor([0., 0., 0., 1.])
SoftTorch greater (soft mode): tensor([0.0180, 0.0067, 0.1192, 1.0000])
Torch greater equal: tensor([False, False, False, True])
SoftTorch greater equal (hard mode): tensor([0., 0., 0., 1.])
SoftTorch greater equal (soft mode): tensor([0.0180, 0.0067, 0.1192, 1.0000])
Torch less: tensor([ True, True, True, False])
SoftTorch less (hard mode): tensor([1., 1., 1., 0.])
SoftTorch less (soft mode): tensor([0.9820, 0.9933, 0.8808, 0.0000])
Torch less equal: tensor([ True, True, True, False])
SoftTorch less equal (hard mode): tensor([1., 1., 1., 0.])
SoftTorch less equal (soft mode): tensor([0.9820, 0.9933, 0.8808, 0.0000])
Torch equal: tensor([False, False, False, False])
SoftTorch equal (hard mode): tensor([0., 0., 0., 0.])
SoftTorch equal (soft mode): tensor([0.0180, 0.0067, 0.1192, 0.0000])
Torch not equal: tensor([True, True, True, True])
SoftTorch not equal (hard mode): tensor([1., 1., 1., 1.])
SoftTorch not equal (soft mode): tensor([0.9820, 0.9933, 0.8808, 1.0000])
Torch isclose: tensor([False, False, False, False])
SoftTorch isclose (hard mode): tensor([0., 0., 0., 0.])
SoftTorch isclose (soft mode): tensor([0.0180, 0.0067, 0.1192, 0.0000])
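The soft-mode values for greater, greater_equal, less and less_equal above are consistent with a logistic sigmoid of the difference x - y at a temperature of 0.1. A minimal sketch under that assumption, with hypothetical names (soft equal and isclose are not covered here, and this is not necessarily Softtorch's implementation):

```python
import torch

TAU = 0.1  # assumed softening temperature


def soft_greater(x, y, tau=TAU):
    # Probability-like score that x > y: a smooth step applied to the difference.
    return torch.sigmoid((x - y) / tau)


def soft_less(x, y, tau=TAU):
    # Complement of soft_greater, so soft_greater + soft_less == 1 elementwise.
    return 1.0 - soft_greater(x, y, tau)


x = torch.tensor([-0.2, -1.0, 0.3, 1.0])
y = torch.tensor([0.2, -0.5, 0.5, -1.0])
print("greater:", soft_greater(x, y))
print("less:   ", soft_less(x, y))
```

Since ties have measure zero for continuous inputs, a relaxation of this form does not distinguish greater from greater_equal, which matches the identical soft-mode rows printed above.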
# Logical operators
fuzzy_a = torch.tensor([0.1, 0.2, 0.8, 1.0])
fuzzy_b = torch.tensor([0.7, 0.3, 0.1, 0.9])
bool_a = fuzzy_a >= 0.5
bool_b = fuzzy_b >= 0.5
print("\nTorch AND:", torch.logical_and(bool_a, bool_b))
print("SoftTorch AND:", st.logical_and(fuzzy_a, fuzzy_b))
print("\nTorch OR:", torch.logical_or(bool_a, bool_b))
print("SoftTorch OR:", st.logical_or(fuzzy_a, fuzzy_b))
print("\nTorch NOT:", torch.logical_not(bool_a))
print("SoftTorch NOT:", st.logical_not(fuzzy_a))
print("\nTorch XOR:", torch.logical_xor(bool_a, bool_b))
print("SoftTorch XOR:", st.logical_xor(fuzzy_a, fuzzy_b))
print("\nTorch ALL:", torch.all(bool_a))
print("SoftTorch ALL:", st.all(fuzzy_a))
print("\nTorch ANY:", torch.any(bool_a))
print("SoftTorch ANY:", st.any(fuzzy_a))
# Selection operators
print("SoftTorch Where:", st.where(fuzzy_a, x, y))
Torch AND: tensor([False, False, False, True])
SoftTorch AND: tensor([0.2646, 0.2449, 0.2828, 0.9487])
Torch OR: tensor([ True, False, True, True])
SoftTorch OR: tensor([0.4804, 0.2517, 0.5757, 1.0000])
Torch NOT: tensor([ True, True, False, False])
SoftTorch NOT: tensor([0.9000, 0.8000, 0.2000, 0.0000])
Torch XOR: tensor([ True, False, True, False])
SoftTorch XOR: tensor([0.5870, 0.4350, 0.6394, 0.1731])
Torch ALL: tensor(False)
SoftTorch ALL: tensor(0.3557)
Torch ANY: tensor(True)
SoftTorch ANY: tensor(0.9981)
SoftTorch Where: tensor([ 0.1600, -0.6000, 0.3400, 1.0000])
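The soft logical values above are consistent with a geometric-mean relaxation: AND as the geometric mean of its inputs, NOT as 1 - a, OR via De Morgan's law, and XOR composed from these, while the soft where is a convex combination weighted by the condition. The sketch below reproduces the printed AND, OR, NOT, XOR, ALL and Where values under that assumption (soft ANY does not follow exactly from it and is omitted). The function names are hypothetical, and this is not necessarily how Softtorch implements them.

```python
import torch


def soft_not(a):
    return 1.0 - a


def soft_and(a, b):
    # Geometric mean: equals the Boolean AND when a and b are exactly 0 or 1.
    return torch.sqrt(a * b)


def soft_or(a, b):
    # De Morgan's law: a OR b = NOT(NOT a AND NOT b).
    return soft_not(soft_and(soft_not(a), soft_not(b)))


def soft_xor(a, b):
    # a XOR b = (a AND NOT b) OR (NOT a AND b).
    return soft_or(soft_and(a, soft_not(b)), soft_and(soft_not(a), b))


def soft_all(a, dim=-1):
    # Geometric mean over a whole dimension.
    return torch.prod(a, dim=dim) ** (1.0 / a.shape[dim])


def soft_where(cond, x, y):
    # Convex combination: reduces to torch.where when cond is exactly 0 or 1.
    return cond * x + (1.0 - cond) * y


fuzzy_a = torch.tensor([0.1, 0.2, 0.8, 1.0])
fuzzy_b = torch.tensor([0.7, 0.3, 0.1, 0.9])
x = torch.tensor([-0.2, -1.0, 0.3, 1.0])
y = torch.tensor([0.2, -0.5, 0.5, -1.0])
print("AND:  ", soft_and(fuzzy_a, fuzzy_b))
print("OR:   ", soft_or(fuzzy_a, fuzzy_b))
print("XOR:  ", soft_xor(fuzzy_a, fuzzy_b))
print("ALL:  ", soft_all(fuzzy_a))
print("Where:", soft_where(fuzzy_a, x, y))
```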
# Straight-through operators: Use hard function on forward and soft on backward
print("Straight-through ReLU:", st.relu_st(x))
print("Straight-through sort:", st.sort_st(x))
print("Straight-through topk:", st.topk_st(x, k=3))
print("Straight-through greater:", st.greater_st(x, y))
# And many more...
Straight-through ReLU: tensor([0.0000, 0.0000, 0.3000, 1.0000])
Straight-through sort: torch.return_types.sort(
values=tensor([-1.0000, -0.2000, 0.3000, 1.0000]),
indices=tensor([[0., 1., 0., 0.],
[1., 0., 0., 0.],
[0., 0., 1., 0.],
[0., 0., 0., 1.]]))
Straight-through topk: torch.return_types.topk(
values=tensor([ 1.0000, 0.3000, -0.2000]),
indices=tensor([[0., 0., 0., 1.],
[0., 0., 1., 0.],
[1., 0., 0., 0.]]))
Straight-through greater: tensor([0., 0., 0., 1.])
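The straight-through pattern can be sketched with the common detach trick: the forward pass returns the hard value, while gradients are taken from a soft surrogate. The version below pairs a hard ReLU with a softplus surrogate at a hypothetical temperature tau. It is an illustrative assumption, not Softtorch's relu_st implementation.

```python
import torch
import torch.nn.functional as F


def relu_straight_through(x: torch.Tensor, tau: float = 0.1) -> torch.Tensor:
    hard = torch.relu(x)
    soft = F.softplus(x, beta=1.0 / tau)
    # The value equals `hard` (up to rounding); the gradient is that of `soft`,
    # because the (hard - soft) correction is detached from the autograd graph.
    return soft + (hard - soft).detach()


x = torch.tensor([-0.2, -1.0, 0.3, 1.0], requires_grad=True)
out = relu_straight_through(x)
out.sum().backward()
print("forward:", out.detach())  # equals torch.relu(x) up to rounding
print("grad:   ", x.grad)        # sigmoid(x / tau), non-zero also for x < 0
```

The hard forward pass keeps downstream computations exact, while the surrogate gradient still provides a learning signal where the hard ReLU would have zero gradient.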
If this library helped your academic work, please consider citing:
@misc{Softtorch2025,
author = {Paulus, Anselm and Geist, Ren\'e and Martius, Georg},
title = {Softtorch},
year = {2025},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/a-paulus/softtorch}}
}
Also consider starring the project on GitHub!
Special thanks and credit go to Patrick Kidger for the awesome JAX repositories that served as the basis for the documentation of this project.
This project is still relatively young. If you have any suggestions for improvement or other feedback, please reach out or open a GitHub issue!
Differentiable sorting, top-k and ranking
DiffSort: Differentiable sorting networks in PyTorch.
DiffTopK: Differentiable top-k in PyTorch.
FastSoftSort: Fast differentiable sorting and ranking in JAX.
Differentiable Top-k with Optimal Transport in JAX.
SoftSort: Differentiable argsort in PyTorch and TensorFlow.
Other
DiffLogic: Differentiable logic gate networks in PyTorch.
SmoothOT: Smooth and Sparse Optimal Transport.
JaxOpt: Differentiable optimization in JAX.
Softtorch builds on and implements various algorithms, e.g. for differentiable top-k, sorting and ranking, including:
Projection onto the probability simplex: An efficient algorithm with a simple proof, and an application
Fast Differentiable Sorting and Ranking.
Differentiable Ranks and Sorting using Optimal Transport
Differentiable Top-k with Optimal Transport
SoftSort: A Continuous Relaxation for the argsort Operator
Sinkhorn Distances: Lightspeed Computation of Optimal Transportation Distances
Smooth and Sparse Optimal Transport
Smooth Approximations of the Rounding Function
Please check the API Documentation for implementation details.