
Commit db7b273

Sort imports. (#16)

* Removed unused imports.
* Removed duplicate imports as flagged in #13.
* Ran isort to organize the imports.

1 parent cadb213 commit db7b273

File tree: 8 files changed, +32 −34 lines

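The third bullet says the imports were organized with isort. The exact invocation is not recorded in the commit, so the sketch below uses isort's Python API (isort >= 5) purely as an illustration of what the tool does to files like the ones in this diff; a CLI run such as `isort .` would be equally plausible.

# Illustrative only: the commit does not record the exact isort command.
import isort

messy = (
    "import torch\n"
    "import sys\n"
    "import os\n"
)

# isort.code() returns the source with imports grouped into sections
# (standard library first, then third-party) and alphabetized.
print(isort.code(messy))
# Expected (roughly):
#   import os
#   import sys
#
#   import torch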

GPTQ.py

Lines changed: 4 additions & 2 deletions
@@ -3,14 +3,16 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import os
+import sys
+
 import torch
-import os, sys
+
 lm_evaluation_harness_path = "/".join(
     os.getcwd().split("/")[:-1] + ["lm-evaluation-harness"]
 )
 sys.path.insert(0, lm_evaluation_harness_path)
 import main as lm_evaluation_harness_main
-
 import torch.fx as fx
 import torch.nn as nn
 import torch.nn.functional as F

eval.py

Lines changed: 9 additions & 11 deletions
@@ -9,9 +9,9 @@
 from typing import Optional
 
 import torch
-
-import torch._inductor.config
 import torch._dynamo.config
+import torch._inductor.config
+
 torch._dynamo.config.automatic_dynamic_shapes = True
 torch._inductor.config.triton.unique_kernel_names = True
 torch._inductor.config.epilogue_fusion = False
@@ -22,23 +22,21 @@
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
-from model import LLaMA
-from sentencepiece import SentencePieceProcessor
-
 # hacky path setup for lm-evaluation-harness
 import os
 import sys
+
+from sentencepiece import SentencePieceProcessor
+
+from model import LLaMA
+
 lm_evaluation_harness_path = '/'.join(
     os.getcwd().split('/')[:-1] + ['lm-evaluation-harness'])
 sys.path.insert(0, lm_evaluation_harness_path)
-import main as lm_evaluation_harness_main
 import lm_eval
+import main as lm_evaluation_harness_main
 
-from generate import (
-    _load_model,
-    encode_tokens,
-    model_forward,
-)
+from generate import _load_model, encode_tokens, model_forward
 
 
 def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
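The second eval.py hunk is the clearest example of isort's default layout in this commit: imports are split into standard-library, third-party, and first-party sections, separated by blank lines and alphabetized within each section. Sketched below with the same modules (classifying `model` as first-party is isort's inference about this repo, not something stated in the commit):

# isort's default section order, shown with the imports from eval.py.

# 1) standard library
import os
import sys

# 2) third-party packages
from sentencepiece import SentencePieceProcessor

# 3) first-party / local code (isort treats `model` as a local module here)
from model import LLaMA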

generate.py

Lines changed: 7 additions & 4 deletions
@@ -3,15 +3,16 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import itertools
 import sys
 import time
 from pathlib import Path
 from typing import Optional, Tuple
-import itertools
-import torch
 
-import torch._inductor.config
+import torch
 import torch._dynamo.config
+import torch._inductor.config
+
 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.triton.unique_kernel_names = True
 torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
@@ -21,9 +22,11 @@
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
+from sentencepiece import SentencePieceProcessor
+
 from model import Transformer
 from tp import maybe_init_dist
-from sentencepiece import SentencePieceProcessor
+
 
 def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
     q = torch.empty_like(probs_sort).exponential_(1)

model.py

Lines changed: 2 additions & 2 deletions
@@ -3,14 +3,14 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-import math
 from dataclasses import dataclass
 from typing import Optional
 
 import torch
 import torch.nn as nn
-from torch.nn import functional as F
 from torch import Tensor
+from torch.nn import functional as F
+
 
 def find_multiple(n: int, k: int) -> int:
     if n % k == 0:

quantize.py

Lines changed: 0 additions & 7 deletions
@@ -3,19 +3,12 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-import importlib
 import time
-from math import ceil
 from pathlib import Path
 
 import torch
-import importlib
-import time
-
 import torch.nn as nn
 import torch.nn.functional as F
-
-from pathlib import Path
 from sentencepiece import SentencePieceProcessor
 
 try:

scripts/convert_hf_checkpoint.py

Lines changed: 2 additions & 1 deletion
@@ -4,19 +4,20 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import json
+import re
 import sys
 from pathlib import Path
 from typing import Optional
 
 import torch
-import re
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
 from model import ModelArgs
 
+
 @torch.inference_mode()
 def convert_hf_checkpoint(
     *,

scripts/download.py

Lines changed: 3 additions & 3 deletions
@@ -4,11 +4,11 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import os
-from requests.exceptions import HTTPError
-import sys
-from pathlib import Path
 from typing import Optional
 
+from requests.exceptions import HTTPError
+
+
 def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None:
     from huggingface_hub import snapshot_download
     os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)

tp.py

Lines changed: 5 additions & 4 deletions
@@ -4,14 +4,15 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import os
-from typing import Optional, List
+from typing import List, Optional
 
 import torch
-from torch import nn
 import torch.distributed as dist
+from torch import nn
 from torch.distributed import _functional_collectives as funcol
-from model import Transformer, Attention, FeedForward
-from quantize import WeightOnlyInt4Linear, WeightOnlyInt8Linear
+
+from model import Attention, FeedForward, Transformer
+from quantize import WeightOnlyInt4Linear
 
 
 def _get_rank() -> int:
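Note that this last hunk also drops `WeightOnlyInt8Linear` from the `quantize` import, which lines up with the "removed unused imports" bullet rather than with sorting: isort only reorders imports, it never deletes them. A linter such as pyflakes is a common way to find unused imports; a minimal sketch, assuming a pyflakes install and the repo checkout as the working directory (the warning text shown is approximate):

# Minimal sketch: pyflakes flags unused imports (isort itself never
# removes them). Assumes `pip install pyflakes` and that tp.py is in
# the current directory.
from pyflakes.api import checkPath

# Prints warnings such as:
#   tp.py:14: 'quantize.WeightOnlyInt8Linear' imported but unused
# and returns the number of warnings emitted.
num_warnings = checkPath("tp.py")
print(f"pyflakes reported {num_warnings} warning(s)")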
