Skip to content

Commit 6264fe8

Browse files
authored
Merge pull request #27 from rajveer43/patch-1
add docstrings
2 parents 8891f35 + 5b3d2b8 commit 6264fe8

File tree

5 files changed

+299
-45
lines changed

5 files changed

+299
-45
lines changed

medusa/model/kv_cache.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
class KVCache:
55
"""
66
A key-value cache for the model.
7+
78
This class provides a mechanism to maintain a growing cache of keys and values,
89
particularly useful for models that benefit from caching previous states,
910
like transformers during autoregressive decoding.
@@ -15,6 +16,8 @@ class KVCache:
1516

1617
def __init__(self, data, current_length):
1718
"""
19+
Initialize the KVCache.
20+
1821
Args:
1922
data (torch.Tensor): Initial tensor to store the keys and values.
2023
current_length (int): Initial length of the data.

medusa/model/medusa_model.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,16 @@
1010

1111

1212
class MedusaConfig(PretrainedConfig):
13+
"""
14+
Configuration class for Medusa model.
15+
16+
Args:
17+
medusa_num_heads (int, optional): Number of heads for the Medusa layer. Default is 2.
18+
medusa_num_layers (int, optional): Number of Medusa layers. Default is 1.
19+
base_model_name_or_path (str, optional): The name or path of the base model. Default is "lmsys/vicuna-7b-v1.3".
20+
**kwargs: Additional keyword arguments to be passed to the parent class constructor.
21+
"""
22+
1323
def __init__(
1424
self,
1525
medusa_num_heads=2,
@@ -24,10 +34,14 @@ def __init__(
2434

2535

2636
class ResBlock(nn.Module):
27-
"""A Residual Block module.
37+
"""
38+
A Residual Block module.
2839
2940
This module performs a linear transformation followed by a SiLU activation,
3041
and then adds the result to the original input, creating a residual connection.
42+
43+
Args:
44+
hidden_size (int): The size of the hidden layers in the block.
3145
"""
3246

3347
def __init__(self, hidden_size):
@@ -39,7 +53,8 @@ def __init__(self, hidden_size):
3953
self.act = nn.SiLU()
4054

4155
def forward(self, x):
42-
"""Forward pass of the ResBlock.
56+
"""
57+
Forward pass of the ResBlock.
4358
4459
Args:
4560
x (torch.Tensor): Input tensor.

0 commit comments

Comments (0)