|
5 | 5 | import argparse |
6 | 6 | parser = argparse.ArgumentParser() |
7 | 7 | parser.add_argument("custom_pytorch_path", help="Path to custom PyTorch wheel") |
| 8 | +parser.add_argument("custom_bitorch_engine_path", help="Path to built bitorch engine wheel file") |
8 | 9 | args = parser.parse_args() |
9 | 10 |
|
10 | | -BLOCK_HEADER_START = "#### Conda on Linux" |
| 11 | +BLOCK_HEADER_START_BINARY = "### Binary Release" |
| 12 | +BLOCK_HEADER_START_FROM_SOURCE = "#### Conda on Linux" |
| 13 | +BLOCK_END = "##########" |
11 | 14 |
|
12 | 15 | with open("README.md") as infile: |
13 | 16 | content = infile.readlines() |
14 | 17 |
|
15 | | -local_install_instructions = [] |
16 | | -global_install_instructions = [] |
| 18 | +with open(".dev-scripts/basic_tests.sh") as infile: |
| 19 | + test_appendix = infile.readlines() |
| 20 | + |
| 21 | + |
| 22 | +def write_file(filepath, main_content): |
| 23 | + with open(filepath, "w") as outfile: |
| 24 | + outfile.write(FILE_INTRO) |
| 25 | + outfile.writelines(main_content) |
| 26 | + outfile.writelines(test_appendix) |
| 27 | + |
| 28 | + |
| 29 | +source_local_install_instructions = [] |
| 30 | +source_global_install_instructions = [] |
| 31 | +binary_local_install_instructions = [] |
| 32 | +binary_global_install_instructions = [] |
17 | 33 |
|
18 | 34 | in_code_block = False |
19 | 35 | reading_instructions = False |
|
36 | 52 | if line.startswith("```"): |
37 | 53 | in_code_block = not in_code_block |
38 | 54 | continue |
39 | | - if line.startswith(BLOCK_HEADER_START): |
| 55 | + if line.startswith(BLOCK_HEADER_START_FROM_SOURCE): |
40 | 56 | reading_instructions = True |
41 | | - instruction_type = "global" |
| 57 | + instruction_type = "source-global" |
| 58 | + BLOCK_END = BLOCK_HEADER_START_FROM_SOURCE.split()[0] |
| 59 | + continue |
| 60 | + if line.startswith(BLOCK_HEADER_START_BINARY): |
| 61 | + reading_instructions = True |
| 62 | + instruction_type = "binary-global" |
| 63 | + BLOCK_END = BLOCK_HEADER_START_BINARY.split()[0] |
42 | 64 | continue |
43 | 65 | if line.startswith("<details><summary>"): |
44 | | - instruction_type = "local" |
| 66 | + if "source" in instruction_type: |
| 67 | + instruction_type = "source-local" |
| 68 | + if "binary" in instruction_type: |
| 69 | + instruction_type = "binary-local" |
45 | 70 | continue |
46 | 71 | if line.startswith("</details>"): |
47 | | - instruction_type = "both" |
| 72 | + if "source" in instruction_type: |
| 73 | + instruction_type = "source-both" |
| 74 | + if "binary" in instruction_type: |
| 75 | + instruction_type = "binary-both" |
48 | 76 | continue |
49 | | - if line.startswith(BLOCK_HEADER_START.split()[0]): |
| 77 | + if line.startswith(BLOCK_END): |
50 | 78 | reading_instructions = False |
51 | 79 | continue |
52 | 80 | if not reading_instructions: |
|
68 | 96 | line = line.replace("${HOME}", "$(pwd)") |
69 | 97 | if line.startswith("pip install torch-"): |
70 | 98 | line = "pip install {}\n".format(args.custom_pytorch_path) |
| 99 | + if line.startswith("pip install bitorch_engine"): |
| 100 | + line = "pip install {}\n".format(args.custom_bitorch_engine_path) |
71 | 101 |
|
72 | 102 | # decide how to write line |
73 | 103 | line_format = "{line}" |
|
78 | 108 | line_format = "\n" + line_format |
79 | 109 |
|
80 | 110 | # write result line(s) |
81 | | - if instruction_type == "global" or instruction_type == "both": |
82 | | - global_install_instructions.append(line_format.format(line=line)) |
83 | | - if instruction_type == "local" or instruction_type == "both": |
84 | | - local_install_instructions.append(line_format.format(line=line)) |
85 | | - |
86 | | -with open(".dev-scripts/test_local_conda_install.sh", "w") as outfile: |
87 | | - outfile.write(FILE_INTRO) |
88 | | - outfile.writelines(local_install_instructions) |
89 | | -with open(".dev-scripts/test_global_conda_install.sh", "w") as outfile: |
90 | | - outfile.write(FILE_INTRO) |
91 | | - outfile.writelines(global_install_instructions) |
| 111 | + if instruction_type == "source-global" or instruction_type == "source-both": |
| 112 | + source_global_install_instructions.append(line_format.format(line=line)) |
| 113 | + if instruction_type == "source-local" or instruction_type == "source-both": |
| 114 | + source_local_install_instructions.append(line_format.format(line=line)) |
| 115 | + if instruction_type == "binary-global" or instruction_type == "binary-both": |
| 116 | + binary_global_install_instructions.append(line_format.format(line=line)) |
| 117 | + if instruction_type == "binary-local" or instruction_type == "binary-both": |
| 118 | + binary_local_install_instructions.append(line_format.format(line=line)) |
| 119 | + |
| 120 | +write_file(".dev-scripts/test_source_local_conda_install.sh", source_local_install_instructions) |
| 121 | +write_file(".dev-scripts/test_source_global_conda_install.sh", source_global_install_instructions) |
| 122 | +write_file(".dev-scripts/test_binary_local_conda_install.sh", binary_local_install_instructions) |
| 123 | +write_file(".dev-scripts/test_binary_global_conda_install.sh", binary_global_install_instructions) |
| 124 | + |
| 125 | +binary_local_cu118 = [line.replace("cu121", "cu118").replace("cuda-12.1.0", "cuda-11.8.0") for line in binary_local_install_instructions] |
| 126 | +write_file(".dev-scripts/test_binary_local_conda_install_cu118.sh", binary_local_cu118) |
| 127 | +binary_local_no_cuda = filter(lambda x: "nvidia/label/cuda-12.1.0" not in x, binary_local_install_instructions) |
| 128 | +write_file(".dev-scripts/test_binary_local_conda_install_no_cuda.sh", binary_local_no_cuda) |
0 commit comments