Skip to content

Commit c92f60d

Browse files
Fix some type issue
1 parent ba37d80 commit c92f60d

File tree

3 files changed

+52
-35
lines changed

3 files changed

+52
-35
lines changed

cyaron/io.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import re
99
import subprocess
1010
import tempfile
11-
from typing import Union, overload
11+
from typing import Union, overload, Optional
1212
from io import IOBase
1313
from . import log
1414
from .utils import list_like, make_unicode
@@ -19,29 +19,30 @@ class IO:
1919

2020
@overload
2121
def __init__(self,
22-
input_file: Union[IOBase, str, int, None] = None,
23-
output_file: Union[IOBase, str, int, None] = None,
24-
data_id: Union[str, None] = None,
22+
input_file: Optional[Union[IOBase, str, int]] = None,
23+
output_file: Optional[Union[IOBase, str, int]] = None,
24+
data_id: Optional[int] = None,
2525
disable_output: bool = False):
2626
...
2727

2828
@overload
2929
def __init__(self,
30-
data_id: Union[str, None] = None,
31-
file_prefix: Union[str, None] = None,
32-
input_suffix: Union[str, None] = '.in',
33-
output_suffix: Union[str, None] = '.out',
30+
data_id: Optional[int] = None,
31+
file_prefix: Optional[str] = None,
32+
input_suffix: str = '.in',
33+
output_suffix: str = '.out',
3434
disable_output: bool = False):
3535
...
3636

37-
def __init__(self,
38-
input_file: Union[IOBase, str, int, None] = None,
39-
output_file: Union[IOBase, str, int, None] = None,
40-
data_id: Union[str, None] = None,
41-
file_prefix: Union[str, None] = None,
42-
input_suffix: Union[str, None] = '.in',
43-
output_suffix: Union[str, None] = '.out',
44-
disable_output: bool = False):
37+
def __init__( # type: ignore
38+
self,
39+
input_file: Optional[Union[IOBase, str, int]] = None,
40+
output_file: Optional[Union[IOBase, str, int]] = None,
41+
data_id: Optional[int] = None,
42+
file_prefix: Optional[str] = None,
43+
input_suffix: str = '.in',
44+
output_suffix: str = '.out',
45+
disable_output: bool = False):
4546
"""
4647
Args:
4748
input_file (optional): input file object or filename or file descriptor.
@@ -216,6 +217,8 @@ def output_gen(self, shell_cmd, time_limit=None):
216217
time_limit: the time limit (seconds) of the command to run.
217218
None means infinity. Defaults to None.
218219
"""
220+
if self.output_file is None:
221+
raise ValueError("Output file is disabled")
219222
self.flush_buffer()
220223
origin_pos = self.input_file.tell()
221224
self.input_file.seek(0)
@@ -224,16 +227,16 @@ def output_gen(self, shell_cmd, time_limit=None):
224227
shell_cmd,
225228
shell=True,
226229
timeout=time_limit,
227-
stdin=self.input_file,
228-
stdout=self.output_file,
230+
stdin=self.input_file.fileno(),
231+
stdout=self.output_file.fileno(),
229232
universal_newlines=True,
230233
)
231234
else:
232235
subprocess.check_call(
233236
shell_cmd,
234237
shell=True,
235-
stdin=self.input_file,
236-
stdout=self.output_file,
238+
stdin=self.input_file.fileno(),
239+
stdout=self.output_file.fileno(),
237240
universal_newlines=True,
238241
)
239242
self.input_file.seek(origin_pos)
@@ -248,6 +251,8 @@ def output_write(self, *args, **kwargs):
248251
*args: the elements to write
249252
separator: a string used to separate every element. Defaults to " ".
250253
"""
254+
if self.output_file is None:
255+
raise ValueError("Output file is disabled")
251256
self.__write(self.output_file, *args, **kwargs)
252257

253258
def output_writeln(self, *args, **kwargs):

cyaron/sequence.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ class Sequence:
1919

2020
def __init__(self,
2121
formula: Callable[[int, Callable[[int], T]], T],
22-
initial_values: Optional[Union[List[T], Tuple[T, ...],
23-
Dict[int, T]]] = ()):
22+
initial_values: Union[List[T], Tuple[T, ...], Dict[int,
23+
T]] = ()):
2424
"""
2525
Initialize a sequence object.
2626
Parameters:

cyaron/vector.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
import random
66
from enum import IntEnum
7-
from typing import Union, Tuple, List, Set
7+
from typing import Sequence, Union, Tuple, List, Set
8+
from typing import cast as typecast
89

910
from .utils import list_like
1011

@@ -48,28 +49,32 @@ def random(
4849

4950
dimension = len(position_range)
5051

51-
offset: List[_Number] = []
52-
length: List[_Number] = []
52+
offset: Sequence[_Number] = []
53+
length: Sequence[_Number] = []
5354

5455
vector_space = 1
5556
for i in range(0, dimension):
56-
if list_like(position_range[i]):
57-
if position_range[i][1] < position_range[i][0]:
57+
now_position_range = position_range[i]
58+
if isinstance(now_position_range, tuple):
59+
if now_position_range[1] < now_position_range[0]:
5860
raise ValueError(
5961
"upper-bound should be larger than lower-bound")
60-
offset.append(position_range[i][0])
61-
length.append(position_range[i][1] - position_range[i][0])
62+
offset.append(now_position_range[0])
63+
length.append(now_position_range[1] - now_position_range[0])
6264
else:
6365
offset.append(0)
64-
length.append(position_range[i])
66+
length.append(now_position_range)
6567
vector_space *= (length[i] + 1)
6668

6769
if mode == VectorRandomMode.unique and num > vector_space:
6870
raise ValueError(
6971
"1st param is so large that CYaRon can not generate unique vectors"
7072
)
7173

74+
result: Union[List[List[int]], List[List[float]]]
7275
if mode == VectorRandomMode.repeatable:
76+
offset = typecast(Sequence[int], offset)
77+
length = typecast(Sequence[int], length)
7378
result = [[
7479
random.randint(x, x + y) for x, y in zip(offset, length)
7580
] for _ in range(num)]
@@ -79,8 +84,11 @@ def random(
7984
] for _ in range(num)]
8085
elif mode == VectorRandomMode.unique and vector_space > 5 * num:
8186
# O(NlogN)
87+
offset = typecast(Sequence[int], offset)
88+
length = typecast(Sequence[int], length)
89+
vector_space = typecast(int, vector_space)
8290
num_set: Set[int] = set()
83-
result: List[List[int]] = []
91+
result = typecast(List[List[int]], [])
8492
for i in range(0, num):
8593
while True:
8694
rand = random.randint(0, vector_space - 1)
@@ -93,6 +101,9 @@ def random(
93101
result.append(tmp)
94102
else:
95103
# generate 0~vector_space and shuffle
104+
offset = typecast(Sequence[int], offset)
105+
length = typecast(Sequence[int], length)
106+
vector_space = typecast(int, vector_space)
96107
rand_arr = list(range(0, vector_space))
97108
random.shuffle(rand_arr)
98109
result = [
@@ -106,13 +117,14 @@ def random(
106117
return result
107118

108119
@staticmethod
109-
def get_vector(dimension: int, position_range: list, hashcode: int):
120+
def get_vector(dimension: int, position_range: Sequence[int],
121+
hashcode: int):
110122
"""
111123
Generates a vector based on the given dimension, position range, and hashcode.
112124
Args:
113-
dimension (int): The number of dimensions for the vector.
114-
position_range (list): A list of integers specifying the range for each dimension.
115-
hashcode (int): A hashcode used to generate the vector.
125+
dimension: The number of dimensions for the vector.
126+
position_range: A list of integers specifying the range for each dimension.
127+
hashcode: A hashcode used to generate the vector.
116128
Returns:
117129
list: A list representing the generated vector.
118130
"""

0 commit comments

Comments
 (0)