88import tempfile
99from enum import IntEnum
1010from functools import lru_cache
11+ from itertools import chain
1112from pathlib import Path
1213from typing import TYPE_CHECKING , Any , NotRequired , Required , TypedDict
1314
@@ -57,28 +58,32 @@ def create_version_pattern(name: str) -> re.Pattern[str]:
5758
5859@lru_cache
5960def resolve_pyproject (
60- pyproject : str | PathLike [str ], temp_directory : str | PathLike [str ]
61+ pyproject : str | PathLike [str ],
62+ temp_directory : str | PathLike [str ],
63+ extras : tuple [str , ...],
6164) -> Path :
6265 origin_pyproject , temp_directory = Path (pyproject ), Path (temp_directory )
6366 new_pyproject = temp_directory / "pyproject.toml"
6467
6568 key , new_pyproject = combine_dev_dependencies (origin_pyproject , new_pyproject )
69+ command = [
70+ "uv" ,
71+ "pip" ,
72+ "compile" ,
73+ new_pyproject .name ,
74+ "-o" ,
75+ "requirements.txt" ,
76+ "--extra" ,
77+ key ,
78+ ]
79+ if extras :
80+ command = [
81+ * command ,
82+ * chain .from_iterable (("--extra" , extra ) for extra in extras ),
83+ ]
6684
6785 uv_process = subprocess .run ( # noqa: S603
68- [ # noqa: S607
69- "uv" ,
70- "pip" ,
71- "compile" ,
72- new_pyproject .name ,
73- "-o" ,
74- "requirements.txt" ,
75- "--extra" ,
76- key ,
77- ],
78- cwd = temp_directory ,
79- check = False ,
80- capture_output = True ,
81- text = True ,
86+ command , cwd = temp_directory , check = False , capture_output = True , text = True
8287 )
8388 try :
8489 uv_process .check_returncode ()
@@ -113,9 +118,12 @@ def find_version(name: str, lock_file: str | PathLike[str]) -> str:
113118
114119
115120def find_version_in_pyproject (
116- name : str , pyproject : str | PathLike [str ], temp_directory : str | PathLike [str ]
121+ name : str ,
122+ pyproject : str | PathLike [str ],
123+ temp_directory : str | PathLike [str ],
124+ extras : tuple [str , ...],
117125) -> str :
118- lock_file = resolve_pyproject (pyproject , temp_directory )
126+ lock_file = resolve_pyproject (pyproject , temp_directory , extras )
119127 return find_version (name , lock_file )
120128
121129
@@ -155,7 +163,10 @@ def resolve_arg(arg_string: str) -> Args:
155163
156164
157165def process (
158- args : list [Args ], pyproject : str | PathLike [str ], pre_commit : str | PathLike [str ]
166+ args : list [Args ],
167+ pyproject : str | PathLike [str ],
168+ pre_commit : str | PathLike [str ],
169+ extras : tuple [str , ...],
159170) -> None :
160171 if args :
161172 logger .info ("Processing args:" )
@@ -168,7 +179,9 @@ def process(
168179 errors : list [tuple [str , str , str , str ] | None ] = [None ] * len (args )
169180 with tempfile .TemporaryDirectory () as temp_directory :
170181 for index , arg in enumerate (args ):
171- version = find_version_in_pyproject (arg ["name" ], pyproject , temp_directory )
182+ version = find_version_in_pyproject (
183+ arg ["name" ], pyproject , temp_directory , extras
184+ )
172185 version = f"{ arg .get ("prefix" , "" )} { version } { arg .get ("suffix" , "" )} "
173186
174187 if hooks [arg ["hook_id" ]] == version :
@@ -208,15 +221,16 @@ def _main() -> None:
208221 "-P" , "--pre-commit" , type = str , default = ".pre-commit-config.yaml"
209222 )
210223 parser .add_argument ("-l" , "--log-level" , type = str , default = "INFO" )
224+ parser .add_argument ("-e" , "--extra" , action = "append" , default = [])
211225 parser .add_argument ("dummy" , nargs = "*" )
212226
213227 args = parser .parse_args ()
214228 args_string : list [str ] = args .args
215229 version_args = [resolve_arg (arg ) for arg in args_string ]
216- pyproject , pre_commit = args .pyproject , args .pre_commit
230+ pyproject , pre_commit , extras = args .pyproject , args .pre_commit , args . extra
217231 logger .setLevel (args .log_level )
218232
219- process (version_args , pyproject , pre_commit )
233+ process (version_args , pyproject , pre_commit , tuple ( extras ) )
220234
221235
222236def main () -> None :
0 commit comments