diff --git a/python/private/pypi/parse_requirements.bzl b/python/private/pypi/parse_requirements.bzl index 9c610f11d3..d6a7d9d2d7 100644 --- a/python/private/pypi/parse_requirements.bzl +++ b/python/private/pypi/parse_requirements.bzl @@ -83,12 +83,106 @@ def parse_requirements( The second element is extra_pip_args should be passed to `whl_library`. """ - evaluate_markers = evaluate_markers or (lambda _ctx, _requirements: {}) + b = parse_requirements_builder( + evaluate_markers = evaluate_markers, + get_index_urls = get_index_urls, + extract_url_srcs = extract_url_srcs, + logger = logger, + ) + for f, plats in requirements_by_platform.items(): + for p in plats: + b = b.with_requirement(f, p) + b = b.with_pip_args(extra_pip_args) + return b.build(ctx) + +def parse_requirements_builder( + *, + evaluate_markers = None, + get_index_urls = None, + extract_url_srcs = True, + logger = None): + """Create a builder for incremental configuration of the parsing. + + Args: + get_index_urls: Callable[[ctx, list[str]], dict], a callable to get all + of the distribution URLs from a PyPI index. Accepts ctx and + distribution names to query. + evaluate_markers: A function to use to evaluate the requirements. + Accepts a dict where keys are requirement lines to evaluate against + the platforms stored as values in the input dict. Returns the same + dict, but with values being platforms that are compatible with the + requirements line. + extract_url_srcs: A boolean to enable extracting URLs from requirement + lines to enable using bazel downloader. + logger: repo_utils.logger or None, a simple struct to log diagnostic messages. + + Returns: + A builder with methods: + * with_requirement - add a requirement to be included in building. + * with_pip_args - add pip args, for all platforms (default) or only for the given platform, to be included in building. + * build - parse the requirements and return the appropriate parameters to create the whl_libraries. 
+ """ + + # buildifier: disable=uninitialized + self = struct( + # buildable components + requirements_by_platform = {}, + extra_pip_args = {}, + # other params + evaluate_markers = evaluate_markers or (lambda _ctx, _requirements: {}), + get_index_urls = get_index_urls, + extract_url_srcs = extract_url_srcs, + logger = logger, + # go/keep-sorted start + build = lambda ctx: _builder_build(self, ctx), + with_pip_args = lambda args, platform = None: _builder_with_pip_args(self, args, platform), + with_requirement = lambda file, platform: _builder_with_requirement(self, file, platform), + # go/keep-sorted end + ) + return self + +def _builder_with_requirement(self, file, platform): + """Add a requirements file that applies to the given platform and return the builder.""" + self.requirements_by_platform.setdefault(file, []).append(platform) + return self + +def _builder_with_pip_args(self, args, platform = None): + """Add pip arguments for the given platform (None means every platform) and return the builder.""" + self.extra_pip_args.setdefault(platform, []).extend(args) + return self + +def _builder_build(self, ctx): + """Get the requirements with platforms that the requirements apply to. + + Args: + self: The builder instance. + ctx: A context that has .read function that would read contents from a label. + + Returns: + {type}`dict[str, list[struct]]` where the key is the distribution name and the struct + contains the following attributes: + * `distribution`: {type}`str` The non-normalized distribution name. + * `srcs`: {type}`struct` The parsed requirement line for easier Simple + API downloading (see `index_sources` return value). + * `target_platforms`: {type}`list[str]` Target platforms that this package is for. + The format is `cp3{minor}_{os}_{arch}`. + * `is_exposed`: {type}`bool` `True` if the package should be exposed via the hub + repository. + * `extra_pip_args`: {type}`list[str]` pip args to use in case we are + not using the bazel downloader to download the archives. This should + be passed to {obj}`whl_library`. 
+ * `whls`: {type}`list[struct]` The list of whl entries that can be + downloaded using the bazel downloader. + * `sdist`: {type}`list[struct]` The sdist that can be downloaded using + the bazel downloader. + + The second element is extra_pip_args should be passed to `whl_library`. + """ options = {} requirements = {} - for file, plats in requirements_by_platform.items(): - if logger: - logger.debug(lambda: "Using {} for {}".format(file, plats)) + for file, plats in self.requirements_by_platform.items(): + if self.logger: + self.logger.debug(lambda: "Using {} for {}".format(file, plats)) contents = ctx.read(file) # Parse the requirements file directly in starlark to get the information @@ -121,10 +215,10 @@ def parse_requirements( for p in opt.split(" "): tokenized_options.append(p) - pip_args = tokenized_options + extra_pip_args + pip_args = tokenized_options + self.extra_pip_args.get(None, []) for plat in plats: requirements[plat] = requirements_dict.values() - options[plat] = pip_args + options[plat] = pip_args + self.extra_pip_args.get(plat, []) requirements_by_platform = {} reqs_with_env_markers = {} @@ -159,16 +253,16 @@ def parse_requirements( # to do, we could use Python to parse the requirement lines and infer the # URL of the files to download things from. This should be important for # VCS package references. 
- env_marker_target_platforms = evaluate_markers(ctx, reqs_with_env_markers) - if logger: - logger.debug(lambda: "Evaluated env markers from:\n{}\n\nTo:\n{}".format( + env_marker_target_platforms = self.evaluate_markers(ctx, reqs_with_env_markers) + if self.logger: + self.logger.debug(lambda: "Evaluated env markers from:\n{}\n\nTo:\n{}".format( reqs_with_env_markers, env_marker_target_platforms, )) index_urls = {} - if get_index_urls: - index_urls = get_index_urls( + if self.get_index_urls: + index_urls = self.get_index_urls( ctx, # Use list({}) as a way to have a set list({ @@ -197,20 +291,20 @@ def parse_requirements( reqs = reqs, index_urls = index_urls, env_marker_target_platforms = env_marker_target_platforms, - extract_url_srcs = extract_url_srcs, - logger = logger, + extract_url_srcs = self.extract_url_srcs, + logger = self.logger, ), ) ret.append(item) - if not item.is_exposed and logger: - logger.debug(lambda: "Package '{}' will not be exposed because it is only present on a subset of platforms: {} out of {}".format( + if not item.is_exposed and self.logger: + self.logger.debug(lambda: "Package '{}' will not be exposed because it is only present on a subset of platforms: {} out of {}".format( name, sorted(requirement_target_platforms), sorted(requirements), )) - if logger: - logger.debug(lambda: "Will configure whl repos: {}".format([w.name for w in ret])) + if self.logger: + self.logger.debug(lambda: "Will configure whl repos: {}".format([w.name for w in ret])) return ret