Skip to content

Commit b51ff26

Browse files
authored
Add usage for gctf.optimizers
1 parent 9735a5c commit b51ff26

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

README.md

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ pip install gradient-centralization-tf
3434
Create a centralized gradients functions for a specified optimizer.
3535

3636
#### Arguments:
37-
- optimizer: a `tf.keras.optimizers.Optimizer object`. The optimizer you are using.
37+
- `optimizer`: a `tf.keras.optimizers.Optimizer object`. The optimizer you are using.
3838

3939
#### Example:
4040

@@ -53,15 +53,29 @@ could point `get_gradients` to this function. This is a modified version of
5353
`tf.keras.optimizers.Optimizer.get_gradients`.
5454

5555
#### Arguments:
56-
- optimizer: a `tf.keras.optimizers.Optimizer object`. The optimizer you are using.
57-
- loss: Scalar tensor to minimize.
58-
- params: List of variables.
56+
- `optimizer`: a `tf.keras.optimizers.Optimizer` object. The optimizer you are using.
57+
- `loss`: Scalar tensor to minimize.
58+
- `params`: List of variables.
5959

6060
#### Returns:
6161
A gradients tensor.
6262

63-
#### Reference:
64-
- [Yong et al., 2020](https://arxiv.org/abs/2004.01461)
63+
### [`gctf.optimizers`](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow/blob/main/gctf/optimizers.py)
64+
65+
Pre built updated optimizers implementing GC.
66+
67+
This module is speciially built for testing out GC and in most cases you would be using [`gctf.centralized_gradients_for_optimizer`](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow#gctfcentralized_gradients_for_optimizer) though this module implements `gctf.centralized_gradients_for_optimizer`. You can directly use all optimizers with [`tf.keras.optimizers`](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers) updated for GC.
68+
69+
#### Example:
70+
71+
```py
72+
>>> model.compile(optimizer = gctf.optimizers.adam(learning_rate = 0.01), ...)
73+
>>> model.compile(optimizer = gctf.optimizers.rmsprop(learning_rate = 0.01, rho = 0.91), ...)
74+
>>> model.compile(optimizer = gctf.optimizers.sgd(), ...)
75+
```
76+
77+
#### Returns:
78+
A `tf.keras.optimizers.Optimizer` object.
6579

6680
## Developing `gctf`
6781

0 commit comments

Comments
 (0)