Skip to content

Commit 4b9e9b7

Browse files
committed
[Fix] Fix TypedDict import error for Python < 3.12
1 parent 1dcd2a6 commit 4b9e9b7

File tree

4 files changed

+6
-6
lines changed

4 files changed

+6
-6
lines changed

xtuner/v1/module/attention/attn_outputs.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
from typing import TypedDict
2-
31
import torch
2+
from typing_extensions import TypedDict
43

54

65
class AttnOutputs(TypedDict, total=False):

xtuner/v1/ops/attn_imp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import traceback
22
from functools import lru_cache
3-
from typing import TypedDict
43

54
import torch
65
import torch.nn as nn
@@ -14,6 +13,7 @@
1413
from torch.nn.attention.flex_attention import (
1514
flex_attention as torch_flex_attention,
1615
)
16+
from typing_extensions import TypedDict
1717

1818
from transformers.models.llama.modeling_llama import repeat_kv
1919

xtuner/v1/rl/base/controller.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import math
22
import os
3-
from typing import Literal, TypedDict
3+
from typing import Literal
44

55
import ray
66
import torch
77
from ray.actor import ActorProxy
8+
from typing_extensions import TypedDict
89

910
from xtuner.v1.data_proto.sequence_context import SequenceContext
1011
from xtuner.v1.model.compose.base import BaseComposeConfig

xtuner/v1/rl/base/worker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import time
55
from itertools import chain
66
from pathlib import Path
7-
from typing import Dict, Iterable, List, TypeAlias, TypedDict, cast
7+
from typing import Dict, Iterable, List, TypeAlias, cast
88

99
import ray
1010
import requests
@@ -16,7 +16,7 @@
1616
from ray.actor import ActorClass, ActorProxy
1717
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
1818
from torch.distributed.tensor import DTensor
19-
from typing_extensions import NotRequired
19+
from typing_extensions import NotRequired, TypedDict
2020

2121
from transformers import AutoTokenizer
2222
from xtuner.v1.config.fsdp import FSDPConfig

0 commit comments

Comments
 (0)