8 files changed (+32, -34 lines)

File 1 of 8:
@@ -3,14 +3,16 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import os
+import sys
+
 import torch
-import os, sys
+
 lm_evaluation_harness_path = "/".join(
     os.getcwd().split("/")[:-1] + ["lm-evaluation-harness"]
 )
 sys.path.insert(0, lm_evaluation_harness_path)
 import main as lm_evaluation_harness_main
-
 import torch.fx as fx
 import torch.nn as nn
 import torch.nn.functional as F
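Note: the hunk keeps the repo's sys.path hack, which assumes an lm-evaluation-harness checkout sits next to the current working directory so that "import main as lm_evaluation_harness_main" resolves to the harness's top-level main.py. A minimal sketch of the same sibling-directory trick, rewritten with pathlib purely for illustration (the pathlib spelling and the directory names are assumptions, not lines from this diff):

import sys
from pathlib import Path

# Assumed layout:
#   parent/
#     your-project/             <- current working directory
#     lm-evaluation-harness/    <- harness checkout being imported
harness_dir = Path.cwd().parent / "lm-evaluation-harness"
sys.path.insert(0, str(harness_dir))
import main as lm_evaluation_harness_main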
File 2 of 8:

@@ -9,9 +9,9 @@
 from typing import Optional
 
 import torch
-
-import torch._inductor.config
 import torch._dynamo.config
+import torch._inductor.config
+
 torch._dynamo.config.automatic_dynamic_shapes = True
 torch._inductor.config.triton.unique_kernel_names = True
 torch._inductor.config.epilogue_fusion = False
@@ -22,23 +22,21 @@
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
-from model import LLaMA
-from sentencepiece import SentencePieceProcessor
-
 # hacky path setup for lm-evaluation-harness
 import os
 import sys
+
+from sentencepiece import SentencePieceProcessor
+
+from model import LLaMA
+
 lm_evaluation_harness_path = '/'.join(
     os.getcwd().split('/')[:-1] + ['lm-evaluation-harness'])
 sys.path.insert(0, lm_evaluation_harness_path)
-import main as lm_evaluation_harness_main
 import lm_eval
+import main as lm_evaluation_harness_main
 
-from generate import (
-    _load_model,
-    encode_tokens,
-    model_forward,
-)
+from generate import _load_model, encode_tokens, model_forward
 
 
 def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
File 3 of 8:

@@ -3,15 +3,16 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import itertools
 import sys
 import time
 from pathlib import Path
 from typing import Optional, Tuple
-import itertools
-import torch
 
-import torch._inductor.config
+import torch
 import torch._dynamo.config
+import torch._inductor.config
+
 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.triton.unique_kernel_names = True
 torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
@@ -21,9 +22,11 @@
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
+from sentencepiece import SentencePieceProcessor
+
 from model import Transformer
 from tp import maybe_init_dist
-from sentencepiece import SentencePieceProcessor
+
 
 def multinomial_sample_one_no_sync(probs_sort):  # Does multinomial sampling without a cuda synchronization
     q = torch.empty_like(probs_sort).exponential_(1)
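Note: the second hunk ends just inside multinomial_sample_one_no_sync, whose comment says it samples a token without a CUDA synchronization (torch.multinomial forces one). The Exponential(1) draw is the start of a Gumbel-max-style trick: argmax(p / q) with q ~ Exp(1) selects index i with probability proportional to p_i. A sketch of the complete function under that reading; only the first line of the body appears in the diff, the return statement is an assumption:

import torch

def multinomial_sample_one_no_sync(probs_sort):
    # argmax(probs / q) with q ~ Exp(1) is equivalent to sampling from
    # the categorical distribution probs, but stays on-device: no
    # host-device synchronization is needed.
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)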
File 4 of 8:

@@ -3,14 +3,14 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-import math
 from dataclasses import dataclass
 from typing import Optional
 
 import torch
 import torch.nn as nn
-from torch.nn import functional as F
 from torch import Tensor
+from torch.nn import functional as F
+
 
 def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
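Note: this hunk is cut off inside find_multiple, which by its name and signature rounds n up to the nearest multiple of k. A plausible completion; everything after the if-branch is an assumption consistent with that contract, not a line from this diff:

def find_multiple(n: int, k: int) -> int:
    # Round n up to the nearest multiple of k,
    # e.g. find_multiple(10, 8) == 16.
    if n % k == 0:
        return n
    return n + k - (n % k)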
File 5 of 8:

@@ -3,19 +3,12 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-import importlib
 import time
-from math import ceil
 from pathlib import Path
 
 import torch
-import importlib
-import time
-
 import torch.nn as nn
 import torch.nn.functional as F
-
-from pathlib import Path
 from sentencepiece import SentencePieceProcessor
 
 try:
File 6 of 8:

@@ -4,19 +4,20 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import json
+import re
 import sys
 from pathlib import Path
 from typing import Optional
 
 import torch
-import re
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
 from model import ModelArgs
 
+
 @torch.inference_mode()
 def convert_hf_checkpoint(
     *,
File 7 of 8:

@@ -4,11 +4,11 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import os
-from requests.exceptions import HTTPError
-import sys
-from pathlib import Path
 from typing import Optional
 
+from requests.exceptions import HTTPError
+
+
 def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None:
     from huggingface_hub import snapshot_download
     os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
4
# This source code is licensed under the license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
import os
7
- from typing import Optional , List
7
+ from typing import List , Optional
8
8
9
9
import torch
10
- from torch import nn
11
10
import torch .distributed as dist
11
+ from torch import nn
12
12
from torch .distributed import _functional_collectives as funcol
13
- from model import Transformer , Attention , FeedForward
14
- from quantize import WeightOnlyInt4Linear , WeightOnlyInt8Linear
13
+
14
+ from model import Attention , FeedForward , Transformer
15
+ from quantize import WeightOnlyInt4Linear
15
16
16
17
17
18
def _get_rank () -> int :
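Note: the last hunk stops at _get_rank. In a torch.distributed job launched with torchrun, each worker's rank is exported via environment variables, so a typical implementation reads them with a single-process fallback. The body below is an assumption; it is not shown in the diff:

import os

def _get_rank() -> int:
    # torchrun exports LOCAL_RANK (and RANK) into each worker's
    # environment; default to 0 when running without a launcher.
    return int(os.environ.get("LOCAL_RANK", "0"))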