Commit 5c48bf4

Author: Muhammed Hasan Celik
Message: doc update
1 parent 98472ad

File tree: 4 files changed, +42 −4 lines

  src/grelu/model/blocks.py
  src/grelu/model/layers.py
  src/grelu/model/models.py
  src/grelu/model/trunks/borzoi.py

src/grelu/model/blocks.py

Lines changed: 4 additions & 2 deletions
@@ -870,7 +870,8 @@ class UnetBlock(nn.Module):
     y_in_channels: Number of channels in the higher-resolution representation.
     norm_type: Type of normalization to apply: 'batch', 'syncbatch', 'layer', 'instance' or None
     norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layers
-    act_func: Name of the activation function
+    act_func: Name of the activation function. Defaults to 'gelu_borzoi', which uses the
+        tanh approximation (different from PyTorch's default GELU implementation).
     dtype: Data type of the weights
     device: Device on which to store the weights
 """
@@ -938,7 +939,8 @@ class UnetTower(nn.Module):
     in_channels: Number of channels in the input
     y_in_channels: Number of channels in the higher-resolution representations.
     n_blocks: Number of U-net blocks
-    act_func: Name of the activation function
+    act_func: Name of the activation function. Defaults to 'gelu_borzoi', which uses the
+        tanh approximation (different from PyTorch's default GELU implementation).
     kwargs: Additional arguments to be passed to the U-net blocks
 """
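
The tanh-versus-exact GELU distinction these docstrings call out is easy to see numerically. A minimal sketch, assuming 'gelu_borzoi' corresponds to PyTorch's tanh-approximated GELU (the docstrings describe it that way; the module itself is not shown in this diff):

    import torch
    import torch.nn as nn

    x = torch.linspace(-3, 3, 7)

    # PyTorch's default GELU computes x * Phi(x) with the exact
    # (erf-based) Gaussian CDF.
    exact = nn.GELU()(x)

    # The tanh approximation the docstrings describe for 'gelu_borzoi';
    # in PyTorch this is spelled nn.GELU(approximate="tanh").
    approx = nn.GELU(approximate="tanh")(x)

    # The difference is small but nonzero, which matters when loading
    # weights trained against one specific variant.
    print((exact - approx).abs().max())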

src/grelu/model/layers.py

Lines changed: 17 additions & 2 deletions
@@ -20,8 +20,15 @@ class Activation(nn.Module):
     A nonlinear activation layer.

     Args:
-        func: The type of activation function. Supported values are 'relu',
-            'elu', 'softplus', 'gelu', 'gelu_borzoi', 'gelu_enformer' and 'exp'. If None, will return nn.Identity.
+        func: The type of activation function. Supported values are:
+            - 'relu': Standard ReLU activation
+            - 'elu': Exponential Linear Unit
+            - 'softplus': Softplus activation
+            - 'gelu': Standard GELU activation using PyTorch's default (exact, erf-based) implementation
+            - 'gelu_borzoi': GELU activation using the tanh approximation (different from PyTorch's default)
+            - 'gelu_enformer': Custom GELU implementation from Enformer
+            - 'exp': Exponential activation
+            - None: Returns nn.Identity (no activation)

     Raises:
         NotImplementedError: If 'func' is not a supported activation function.
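
For reference, a plausible name-to-module mapping implied by this docstring (a sketch, not the library's actual implementation; the custom 'gelu_enformer' and 'exp' modules are not defined in this diff and are omitted here):

    import torch.nn as nn

    # Sketch of the mapping the docstring implies.
    ACTIVATIONS = {
        "relu": nn.ReLU(),
        "elu": nn.ELU(),
        "softplus": nn.Softplus(),
        "gelu": nn.GELU(),                           # exact, erf-based
        "gelu_borzoi": nn.GELU(approximate="tanh"),  # tanh approximation
        None: nn.Identity(),                         # no activation
    }
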
@@ -159,6 +166,14 @@ class Norm(nn.Module):
         'syncbatch', 'instance', or 'layer'. If None, will return nn.Identity.
     in_dim: Number of features in the input tensor.
     **kwargs: Additional arguments to pass to the normalization function.
+        Common arguments include:
+        - eps: Small constant added to the denominator for numerical stability.
+            Defaults to 1e-5 for all normalization types unless overridden.
+        - momentum: Value used for the running_mean and running_var computation.
+            Defaults to 0.1 for batch and sync batch norm.
+        - affine: If True, adds learnable affine parameters. Defaults to True.
+        - track_running_stats: If True, tracks running mean and variance.
+            Defaults to True for batch and sync batch norm.
 """

 def __init__(
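
What these kwargs mean in practice is easiest to see on the underlying PyTorch layer. A minimal sketch using batch norm over 64 channels, with the eps=0.001 value that models.py documents as Borzoi's default:

    import torch.nn as nn

    # The documented kwargs, as they are forwarded to the PyTorch layer.
    norm = nn.BatchNorm1d(
        64,                         # in_dim: number of input features
        eps=1e-3,                   # Borzoi's documented norm_kwargs default
        momentum=0.1,               # running-stats update rate
        affine=True,                # learnable scale and shift
        track_running_stats=True,   # keep running mean/var for eval mode
    )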

src/grelu/model/models.py

Lines changed: 17 additions & 0 deletions
@@ -496,6 +496,10 @@ class BorzoiModel(BaseModel):
         If None, no pooling will be applied at the end.
     flash_attn: If True, uses Flash Attention with Rotary Position Embeddings. key_len, value_len,
         pos_dropout and n_pos_features are ignored.
+    norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layers.
+        Defaults to {"eps": 0.001}.
+    act_func: Name of the activation function. Defaults to 'gelu_borzoi', which uses the
+        tanh approximation (different from PyTorch's default GELU implementation).
     dtype: Data type for the layers.
     device: Device for the layers.
 """
@@ -570,6 +574,19 @@ def __init__(
 class BorzoiPretrainedModel(BaseModel):
     """
     Borzoi model with published weights (ported from Keras).
+
+    Args:
+        n_tasks: Number of tasks for the model to predict
+        fold: Which fold of the model to load (default=0)
+        n_transformers: Number of transformer blocks to use (default=8)
+        crop_len: Number of positions to crop at either end of the output (default=0)
+        act_func: Name of the activation function. Defaults to 'gelu_borzoi', which uses the
+            tanh approximation (different from PyTorch's default GELU implementation).
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layers.
+            Defaults to {"eps": 0.001}.
+        final_pool_func: Name of the pooling function to apply to the final output (default="avg")
+        dtype: Data type for the layers
+        device: Device for the layers
     """

     def __init__(
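
A hypothetical usage sketch tying the new defaults together. The argument names come from the docstring above; the values, and the completeness of the argument list, are assumptions:

    from grelu.model.models import BorzoiPretrainedModel

    # Hypothetical instantiation; names are from the docstring, values
    # are illustrative only.
    model = BorzoiPretrainedModel(
        n_tasks=10,                  # illustrative task count
        fold=0,                      # documented default
        act_func="gelu_borzoi",      # documented default (tanh-approx GELU)
        norm_kwargs={"eps": 0.001},  # documented default
    )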

src/grelu/model/trunks/borzoi.py

Lines changed: 4 additions & 0 deletions
@@ -22,6 +22,8 @@ class BorzoiConvTower(nn.Module):
     n_blocks: Number of convolutional/pooling blocks, including the stem
     norm_type: Type of normalization to apply: 'batch', 'syncbatch', 'layer', 'instance' or None
     norm_kwargs: Additional arguments to be passed to the normalization layer
+    act_func: Name of the activation function. Defaults to 'gelu_borzoi', which uses the
+        tanh approximation (different from PyTorch's default GELU implementation).
     dtype: Data type for the layers.
     device: Device for the layers.
 """
@@ -123,6 +125,8 @@ class BorzoiTrunk(nn.Module):
         pos_dropout and n_pos_features are ignored.
     norm_type: Type of normalization to apply: 'batch', 'syncbatch', 'layer', 'instance' or None
     norm_kwargs: Additional arguments to be passed to the normalization layer
+    act_func: Name of the activation function. Defaults to 'gelu_borzoi', which uses the
+        tanh approximation (different from PyTorch's default GELU implementation).
     dtype: Data type for the layers.
     device: Device for the layers.
 """

0 commit comments