Skip to content

Commit 7e3c1d5

Browse files
authored
Format code on master (black + isort) (#538)
* Config files * Add autoflake * Update isort exclude; add pre-commit to requirements * Manually fix a few bad cases
1 parent 17ad489 commit 7e3c1d5

File tree

324 files changed

+5620
-3595
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

324 files changed

+5620
-3595
lines changed

.github/linters/pyproject.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
[tool.black]
2+
line-length = 120
3+
4+
[tool.isort]
5+
profile = "black"
6+
line_length = 120
7+
known_first_party = "maro"

.github/linters/tox.ini

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ ignore =
55
# line break after binary operator
66
W504,
77
# line break before binary operator
8-
W503
8+
W503,
9+
# whitespace before ':'
10+
E203
911

1012
exclude =
1113
.git,
@@ -27,14 +29,5 @@ max-line-length = 120
2729
per-file-ignores =
2830
# import not used: ignore in __init__.py files
2931
__init__.py:F401
30-
# igore invalid escape sequence in cli main script to show banner
32+
# ignore invalid escape sequence in cli main script to show banner
3133
maro.py:W605
32-
33-
[isort]
34-
indent = " "
35-
line_length = 120
36-
use_parentheses = True
37-
multi_line_output = 6
38-
known_first_party = maro
39-
filter_files = True
40-
skip_glob = maro/__init__.py, tests/*, examples/*, setup.py

.github/workflows/lint.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@ jobs:
4545
uses: github/super-linter@latest
4646
env:
4747
VALIDATE_ALL_CODEBASE: false
48-
VALIDATE_PYTHON_PYLINT: false # disable pylint, as we have not configure it
49-
VALIDATE_PYTHON_BLACK: false # same as above
48+
VALIDATE_PYTHON_PYLINT: false # disable pylint, as we have not configured it
5049
VALIDATE_PYTHON_MYPY: false # same as above
5150
VALIDATE_JSCPD: false # Can not exclude specific file: https://github.com/kucherenko/jscpd/issues/215
5251
PYTHON_FLAKE8_CONFIG_FILE: tox.ini
53-
PYTHON_ISORT_CONFIG_FILE: tox.ini
52+
PYTHON_BLACK_CONFIG_FILE: pyproject.toml
53+
PYTHON_ISORT_CONFIG_FILE: pyproject.toml
5454
EDITORCONFIG_FILE_NAME: ../../.editorconfig
5555
FILTER_REGEX_INCLUDE: maro/.*
5656
FILTER_REGEX_EXCLUDE: tests/.*

.pre-commit-config.yaml

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
repos:
5+
- repo: https://github.com/myint/autoflake
6+
rev: v1.4
7+
hooks:
8+
- id: autoflake
9+
args:
10+
- --in-place
11+
- --remove-unused-variables
12+
- --remove-all-unused-imports
13+
exclude: .*/__init__\.py|setup\.py
14+
- repo: https://github.com/pycqa/isort
15+
rev: 5.10.1
16+
hooks:
17+
- id: isort
18+
args:
19+
- --settings-path=.github/linters/pyproject.toml
20+
- --check
21+
- repo: https://github.com/asottile/add-trailing-comma
22+
rev: v2.2.3
23+
hooks:
24+
- id: add-trailing-comma
25+
name: add-trailing-comma (1st round)
26+
- repo: https://github.com/psf/black
27+
rev: 22.3.0
28+
hooks:
29+
- id: black
30+
name: black (1st round)
31+
args:
32+
- --config=.github/linters/pyproject.toml
33+
- repo: https://github.com/asottile/add-trailing-comma
34+
rev: v2.2.3
35+
hooks:
36+
- id: add-trailing-comma
37+
name: add-trailing-comma (2nd round)
38+
- repo: https://github.com/psf/black
39+
rev: 22.3.0
40+
hooks:
41+
- id: black
42+
name: black (2nd round)
43+
args:
44+
- --config=.github/linters/pyproject.toml
45+
- repo: https://gitlab.com/pycqa/flake8
46+
rev: 3.7.9
47+
hooks:
48+
- id: flake8
49+
args:
50+
- --config=.github/linters/tox.ini
51+
exclude: \.git|__pycache__|docs|build|dist|.*\.egg-info|docker_files|\.vscode|\.github|scripts|tests|maro\/backends\/.*.cp|setup.py

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565

6666

6767
# The name of the Pygments (syntax highlighting) style to use.
68-
pygments_style = 'sphinx'
68+
pygments_style = "sphinx"
6969

7070
# If true, `todo` and `todoList` produce output, else they produce nothing.
7171
todo_include_todos = False

examples/cim/rl/algorithms/ac.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT license.
33

4-
from typing import Dict
5-
64
import torch
75
from torch.optim import Adam, RMSprop
86

97
from maro.rl.model import DiscreteACBasedNet, FullyConnected, VNet
108
from maro.rl.policy import DiscretePolicyGradient
11-
from maro.rl.training.algorithms import ActorCriticTrainer, ActorCriticParams
9+
from maro.rl.training.algorithms import ActorCriticParams, ActorCriticTrainer
1210

1311
actor_net_conf = {
1412
"hidden_dims": [256, 128, 64],
@@ -58,10 +56,10 @@ def get_ac(state_dim: int, name: str) -> ActorCriticTrainer:
5856
name=name,
5957
params=ActorCriticParams(
6058
get_v_critic_net_func=lambda: MyCriticNet(state_dim),
61-
reward_discount=.0,
59+
reward_discount=0.0,
6260
grad_iters=10,
6361
critic_loss_cls=torch.nn.SmoothL1Loss,
6462
min_logp=None,
65-
lam=.0,
63+
lam=0.0,
6664
),
6765
)

examples/cim/rl/algorithms/dqn.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT license.
33

4-
from typing import Dict
5-
64
import torch
75
from torch.optim import RMSprop
86

97
from maro.rl.exploration import MultiLinearExplorationScheduler, epsilon_greedy
108
from maro.rl.model import DiscreteQNet, FullyConnected
119
from maro.rl.policy import ValueBasedPolicy
12-
from maro.rl.training.algorithms import DQNTrainer, DQNParams
10+
from maro.rl.training.algorithms import DQNParams, DQNTrainer
1311

1412
q_net_conf = {
1513
"hidden_dims": [256, 128, 64, 32],
@@ -38,14 +36,18 @@ def get_dqn_policy(state_dim: int, action_num: int, name: str) -> ValueBasedPoli
3836
name=name,
3937
q_net=MyQNet(state_dim, action_num),
4038
exploration_strategy=(epsilon_greedy, {"epsilon": 0.4}),
41-
exploration_scheduling_options=[(
42-
"epsilon", MultiLinearExplorationScheduler, {
43-
"splits": [(2, 0.32)],
44-
"initial_value": 0.4,
45-
"last_ep": 5,
46-
"final_value": 0.0,
47-
}
48-
)],
39+
exploration_scheduling_options=[
40+
(
41+
"epsilon",
42+
MultiLinearExplorationScheduler,
43+
{
44+
"splits": [(2, 0.32)],
45+
"initial_value": 0.4,
46+
"last_ep": 5,
47+
"final_value": 0.0,
48+
},
49+
),
50+
],
4951
warmup=100,
5052
)
5153

@@ -54,7 +56,7 @@ def get_dqn(name: str) -> DQNTrainer:
5456
return DQNTrainer(
5557
name=name,
5658
params=DQNParams(
57-
reward_discount=.0,
59+
reward_discount=0.0,
5860
update_target_every=5,
5961
num_epochs=10,
6062
soft_update_coef=0.1,

examples/cim/rl/algorithms/maddpg.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,29 @@
22
# Licensed under the MIT license.
33

44
from functools import partial
5-
from typing import Dict, List
5+
from typing import List
66

77
import torch
88
from torch.optim import Adam, RMSprop
99

1010
from maro.rl.model import DiscreteACBasedNet, FullyConnected, MultiQNet
1111
from maro.rl.policy import DiscretePolicyGradient
12-
from maro.rl.training.algorithms import DiscreteMADDPGTrainer, DiscreteMADDPGParams
13-
12+
from maro.rl.training.algorithms import DiscreteMADDPGParams, DiscreteMADDPGTrainer
1413

1514
actor_net_conf = {
1615
"hidden_dims": [256, 128, 64],
1716
"activation": torch.nn.Tanh,
1817
"softmax": True,
1918
"batch_norm": False,
20-
"head": True
19+
"head": True,
2120
}
2221
critic_net_conf = {
2322
"hidden_dims": [256, 128, 64],
2423
"output_dim": 1,
2524
"activation": torch.nn.LeakyReLU,
2625
"softmax": False,
2726
"batch_norm": True,
28-
"head": True
27+
"head": True,
2928
}
3029
actor_learning_rate = 0.001
3130
critic_learning_rate = 0.001
@@ -64,9 +63,9 @@ def get_maddpg(state_dim: int, action_dims: List[int], name: str) -> DiscreteMAD
6463
return DiscreteMADDPGTrainer(
6564
name=name,
6665
params=DiscreteMADDPGParams(
67-
reward_discount=.0,
66+
reward_discount=0.0,
6867
num_epoch=10,
6968
get_q_critic_net_func=partial(get_multi_critic_net, state_dim, action_dims),
70-
shared_critic=False
71-
)
69+
shared_critic=False,
70+
),
7271
)

examples/cim/rl/algorithms/ppo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ def get_ppo(state_dim: int, name: str) -> PPOTrainer:
1515
name=name,
1616
params=PPOParams(
1717
get_v_critic_net_func=lambda: MyCriticNet(state_dim),
18-
reward_discount=.0,
18+
reward_discount=0.0,
1919
grad_iters=10,
2020
critic_loss_cls=torch.nn.SmoothL1Loss,
2121
min_logp=None,
22-
lam=.0,
22+
lam=0.0,
2323
clip_ratio=0.1,
2424
),
2525
)

examples/cim/rl/config.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
env_conf = {
55
"scenario": "cim",
66
"topology": "toy.4p_ssdd_l0.0",
7-
"durations": 560
7+
"durations": 560,
88
}
99

1010
if env_conf["topology"].startswith("toy"):
@@ -17,27 +17,26 @@
1717

1818
state_shaping_conf = {
1919
"look_back": 7,
20-
"max_ports_downstream": 2
20+
"max_ports_downstream": 2,
2121
}
2222

2323
action_shaping_conf = {
2424
"action_space": [(i - 10) / 10 for i in range(21)],
2525
"finite_vessel_space": True,
26-
"has_early_discharge": True
26+
"has_early_discharge": True,
2727
}
2828

2929
reward_shaping_conf = {
3030
"time_window": 99,
3131
"fulfillment_factor": 1.0,
3232
"shortage_factor": 1.0,
33-
"time_decay": 0.97
33+
"time_decay": 0.97,
3434
}
3535

3636
# obtain state dimension from a temporary env_wrapper instance
37-
state_dim = (
38-
(state_shaping_conf["look_back"] + 1) * (state_shaping_conf["max_ports_downstream"] + 1) * len(port_attributes)
39-
+ len(vessel_attributes)
40-
)
37+
state_dim = (state_shaping_conf["look_back"] + 1) * (state_shaping_conf["max_ports_downstream"] + 1) * len(
38+
port_attributes,
39+
) + len(vessel_attributes)
4140

4241
action_num = len(action_shaping_conf["action_space"])
4342

0 commit comments

Comments
 (0)