Skip to content

Commit 303454c

Browse files
Merge pull request #2 from BrandonGroth/mx_impl_patch
feat: Allow patching for MX library
2 parents 40a5c0c + 352e77d commit 303454c

File tree

3 files changed

+136
-0
lines changed

3 files changed

+136
-0
lines changed

install_patches.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Standard
2+
import os
3+
import subprocess
4+
5+
dependencies_with_patch = {
6+
"microxcaling": "https://github.com/microsoft/microxcaling.git",
7+
}
8+
9+
10+
def install_with_patch(
11+
pkg_name: str,
12+
repo_url: str,
13+
patch_file: str,
14+
home_dir: str = None,
15+
) -> None:
16+
"""
17+
Install a dependency with a patch file
18+
19+
Args:
20+
pkg_name (str): Name of package being installed
21+
repo_url (str): Github repo URL
22+
patch_file (str): Patch file in patches/<patch_file>
23+
home_dir (str): Home directory with fms-model-optimizer and other packages.
24+
Defaults to None.
25+
"""
26+
# We want to git clone the repo to $HOME/repo_name
27+
if home_dir is None:
28+
home_dir = os.path.expanduser("~")
29+
30+
# Get fms_mo directory in home_dir
31+
cwd = os.getcwd()
32+
33+
# Get patch file location from fms-model-optimizer
34+
patch_file = os.path.join(cwd, "patches", patch_file)
35+
if not os.path.exists(patch_file):
36+
raise FileNotFoundError(f"Can't find {pkg_name} patch file in {cwd}/patches")
37+
38+
# Check to see if package exists in cwd or home_dir
39+
pkg_path_cwd = os.path.join(cwd, pkg_name)
40+
pkg_path_home = os.path.join(home_dir, pkg_name)
41+
pkg_exists_cwd = os.path.exists(pkg_path_cwd)
42+
pkg_exists_home = os.path.exists(pkg_path_home)
43+
44+
# If pkg already exists in cwd or home_dir, skip clone
45+
if pkg_exists_cwd:
46+
pkg_dir = pkg_path_cwd
47+
print(f"Directory {pkg_dir} already exists. Skipping download.")
48+
elif pkg_exists_home:
49+
pkg_dir = pkg_path_home
50+
print(f"Directory {pkg_dir} already exists. Skipping download.")
51+
else:
52+
# Clone repo to home directory
53+
pkg_dir = pkg_path_home
54+
subprocess.run(["git", "clone", repo_url], cwd=home_dir, check=True)
55+
56+
# Apply patch and pip install package
57+
try:
58+
subprocess.run(["git", "apply", "--check", patch_file], cwd=pkg_dir, check=True)
59+
subprocess.run(["git", "apply", patch_file], cwd=pkg_dir, check=True)
60+
print(
61+
f"FMS Model Optimizer patch for {pkg_name} applied. Installing package now."
62+
)
63+
subprocess.run(["pip", "install", "."], cwd=pkg_dir, check=True)
64+
65+
except subprocess.CalledProcessError as e:
66+
print(
67+
f"FMS Model Optimizer patch for {pkg_name} is already installed "
68+
f"or an error has occured: \n{e}"
69+
)
70+
71+
72+
def install_dependencies_with_patch() -> None:
73+
"""
74+
Script to install depenencies that requires a patch prior to pip install.
75+
76+
To execute, use `python install_patches.py`.
77+
78+
Requirements:
79+
1. The patch file is named <package>.patch
80+
2. Patch file must be located in fms-model-optimizer/patches
81+
"""
82+
for pkg, repo_url in dependencies_with_patch.items():
83+
install_with_patch(
84+
pkg_name=pkg,
85+
repo_url=repo_url,
86+
patch_file=pkg + ".patch",
87+
)
88+
89+
90+
if __name__ == "__main__":
91+
install_dependencies_with_patch()

patches/README.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
## Patching Third Party Dependencies
2+
Some dependencies clash with the current FMS Model Optimizer environment and we need to apply a patch.
3+
To do this, we have provided a script in `fms-model-optimizer` named `install_patches.py`.
4+
5+
To run this script:
6+
```
7+
python3 install_patches.py
8+
```
9+
10+
The following optional packages require a patch:
11+
* `microxcaling`: Uses outdated versions of PyTorch-related packages
12+
13+
## Making a Patch File
14+
To make a git diff patch file, first make your desired changes to the repository. Then run
15+
```
16+
git diff > <package>.patch
17+
```
18+
Packages may include files that differ by whitespaces even if you didn't change them.
19+
To address this, add `--ignore-all-spaces` to the `git diff` command.
20+
21+
To test the patch file, copy the `<package>.patch` file to `fms-model-optimizer/patches`.
22+
Next add a new entry to the `install_patches.py` dictionary called `dependencies_with_patch` with the package name and repo URL:
23+
```
24+
dependencies_with_patch = {
25+
<package>: <URL>, # for <package>.patch
26+
}
27+
```
28+
Lastly, run the python command for `install_patches.py`.

patches/microxcaling.patch

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
diff --git a/pyproject.toml b/pyproject.toml
2+
index e80053e..b4ec100 100644
3+
--- a/pyproject.toml
4+
+++ b/pyproject.toml
5+
@@ -5,9 +5,9 @@ description = 'The Microsoft MX floating point library'
6+
readme = "README.md"
7+
requires-python = ">=3.8"
8+
dependencies = [
9+
- "torch==2.2.0",
10+
- "torchvision==0.16",
11+
- "torchaudio==2.1.0"
12+
+ "torch",
13+
+ "torchvision",
14+
+ "torchaudio"
15+
]
16+
license = { file = "LICENSE" }
17+
keywords = ["mx", "floating point", "math", "mathematics", "machine learning", "deep learning", "artificial intelligence", "ai", "ml", "dl", "torch", "torchvision", "torchaudio"]

0 commit comments

Comments
 (0)