|
| 1 | +include("file_io.jl") |
| 2 | +using .FileIO |
| 3 | + |
| 4 | +include("functions.jl") |
| 5 | +using .Functions |
| 6 | + |
| 7 | +include("error_calculator.jl") |
| 8 | +using .Err |
| 9 | + |
| 10 | +using Printf |
| 11 | +using Format |
| 12 | +using Match |
| 13 | +using Logging |
| 14 | + |
| 15 | + |
| 16 | +# Disable the line below to enable overflow warnings in certain mathematical
| 17 | +# function tests.
| 18 | +Logging.disable_logging(Logging.Warn) |
| 19 | + |
"""
    PaddedFloat64

A mutable `Float64` slot (`val`) followed by seven unused `Float64`
padding entries, giving the struct a 64-byte layout. Used for the shared
per-thread result arrays below, apparently to keep each thread's slot on
its own cache line. NOTE(review): because the struct is mutable, a
`Vector{PaddedFloat64}` stores references, so whether the padding
actually prevents false sharing depends on heap placement — confirm.
"""
mutable struct PaddedFloat64
    val::Float64
    # Filler only; never read. 1 + 7 Float64s = 64 bytes.
    padding::NTuple{7, Float64}
end

# Convenience constructor: a zero-initialized slot. Replaces the repeated
# `PaddedFloat64(0.0, ntuple(_ -> 0.0, 7))` pattern at the init sites.
PaddedFloat64() = PaddedFloat64(0.0, ntuple(_ -> 0.0, 7))
| 24 | + |
"""
    check_run_time(ns_budget, start_float, end_float, t_num_floats,
                   func_name, data_format, rounding, fastmath_on, search)

Estimate how many tests fit into a wall-clock budget of `ns_budget`
nanoseconds. Runs a short calibration pass of `calibration_runs` tests
per thread via `spawn_threads`, measures the mean time per test, and
returns `floor(ns_budget / mean_ns_per_test)` (a `Float64`, matching the
original contract).

NOTE(review): the calibration pass writes into the shared per-thread
result arrays; callers rely on the subsequent real run overwriting them.
"""
function check_run_time(ns_budget, start_float, end_float, t_num_floats,
                        func_name, data_format, rounding, fastmath_on,
                        search)
    # Single named constant instead of the duplicated magic 100000
    # (previously hard-coded both as the spawn argument and the divisor).
    calibration_runs = 100_000
    start_time = time_ns()
    spawn_threads(start_float, end_float, t_num_floats,
                  func_name, data_format, rounding, fastmath_on,
                  calibration_runs, search)
    end_time = time_ns()
    # Mean wall-clock time per individual test, in nanoseconds.
    ns_per_test = (end_time - start_time) / calibration_runs
    return floor(ns_budget / ns_per_test)
end
| 39 | + |
# One zero-initialized padded slot per worker thread, plus one extra slot
# (index `Threads.nthreads()+1`) used by the special-input tests that run
# on the main thread.
padded_slots() =
    [PaddedFloat64(0.0, ntuple(_ -> 0.0, 7)) for _ in 1:Threads.nthreads()+1]

# `const` keeps these global bindings type-stable (the bindings are never
# reassigned; only the elements' `val` fields are mutated by the workers).
const max_error       = padded_slots()  # largest ULP error seen per slot
const max_input       = padded_slots()  # input that produced that error
const max_output      = padded_slots()  # tested function's output there
const max_ref_out     = padded_slots()  # MPFR reference output there
const number_of_tests = padded_slots()  # number of tests run per slot

# Tasks spawned by `spawn_threads`.
const thread_tasks = Task[]
| 52 | + |
"""
    spawn_threads(start_float, end_float, t_num_floats, func_name,
                  data_format, rounding, fastmath_on, tests_to_do, search)

Split `[start_float, end_float]` into `Threads.nthreads()` sub-intervals
of `t_num_floats` floats each (the final sub-interval extends to
`end_float`) and run the selected error search on each: all but the last
on spawned tasks, the last on the current thread. Each worker writes its
results into its own slot `tn` of the shared padded arrays, which are
returned as `(max_error, max_input, max_output, max_ref_out,
number_of_tests)`.

`search == "exhaustive"` selects `Err.function_max_error_exhaustive`;
any other value selects `Err.function_max_error_fixed_step`, running
`tests_to_do` tests per sub-interval.
"""
function spawn_threads(start_float, end_float, t_num_floats,
                       func_name, data_format, rounding, fastmath_on,
                       tests_to_do, search)

    # Collect this call's tasks locally. The previous version pushed into
    # the never-emptied global `thread_tasks`, so every call re-waited on
    # (and retained) the tasks of all earlier calls — a slow leak, since
    # this function runs once per calibration and once per tested function.
    tasks = Task[]

    for tn in 1:Threads.nthreads()

        # Sub-interval endpoints for worker `tn`.
        sub_int_start = Err.nextfloatn(start_float, (tn-1)*t_num_floats,
                                       data_format)
        if tn == Threads.nthreads()
            sub_int_end = end_float
        else
            sub_int_end = Err.nextfloatn(sub_int_start, t_num_floats-1,
                                         data_format)
        end

        if tn != Threads.nthreads()
            # Spawn a task for every sub-interval except the last; each
            # task assigns into its own (padded) slot `tn`.
            if search == "exhaustive"
                t = Threads.@spawn (max_error[tn].val,
                                    max_input[tn].val,
                                    max_output[tn].val,
                                    max_ref_out[tn].val,
                                    number_of_tests[tn].val) =
                    Err.function_max_error_exhaustive(
                        func_name, data_format, rounding, fastmath_on,
                        sub_int_start, sub_int_end)
            else
                t = Threads.@spawn (max_error[tn].val,
                                    max_input[tn].val,
                                    max_output[tn].val,
                                    max_ref_out[tn].val,
                                    number_of_tests[tn].val) =
                    Err.function_max_error_fixed_step(
                        func_name, data_format, rounding, fastmath_on,
                        sub_int_start, sub_int_end, tests_to_do)
            end
            push!(tasks, t)
        else
            # Run the last sub-interval on the current thread.
            if search == "exhaustive"
                (max_error[tn].val,
                 max_input[tn].val,
                 max_output[tn].val,
                 max_ref_out[tn].val,
                 number_of_tests[tn].val) =
                    Err.function_max_error_exhaustive(
                        func_name, data_format, rounding, fastmath_on,
                        sub_int_start, sub_int_end)
            else
                (max_error[tn].val,
                 max_input[tn].val,
                 max_output[tn].val,
                 max_ref_out[tn].val,
                 number_of_tests[tn].val) =
                    Err.function_max_error_fixed_step(
                        func_name, data_format, rounding, fastmath_on,
                        sub_int_start, sub_int_end, tests_to_do)
            end
        end
    end

    # Block until every spawned task has finished writing its slot.
    foreach(wait, tasks)

    return (max_error, max_input, max_output, max_ref_out, number_of_tests)
end
| 117 | + |
config_file = "config.json"

# Read and validate testing tasks specified in the json file.
tasks = FileIO.read_input_file(config_file)
mkpath("output")

for (task_name, task_details) in tasks

    printstyled("Validating task: $task_name\n", color=:blue)
    (data_format, search, rounding, fastmath_on) =
        FileIO.validate_tasks(task_details)

    # NOTE(review): the result of this @match is discarded — `rounding`
    # keeps its string value ("RN"/"RZ"/...) for every call below. If the
    # symbols were meant to replace the string, this needs
    # `rounding = @match rounding begin ... end`; confirm what the Err
    # functions expect before changing it.
    @match rounding begin
        "RN" => :RoundNearest
        "RZ" => :RoundToZero
        "RD" => :RoundDown
        "RU" => :RoundUp
    end

    # NOTE(review): elsewhere `data_format` is compared to strings;
    # presumably `validate_tasks` returns 1 to flag an invalid task —
    # verify against FileIO.validate_tasks.
    if data_format == 1
        printstyled("Skipping task: $task_name\n\n", color=:red)
        continue
    end

    # Set MPFR global precision to be 20 bits more than the specified
    # format's (binary16: 11+20, binary32: 24+20, binary64: 53+20 bits).
    if (data_format == "binary16")
        setprecision(BigFloat, 31)
    elseif (data_format == "binary32")
        setprecision(BigFloat, 44)
    else
        setprecision(BigFloat, 73)
    end

    printstyled("Format: $data_format Search: $search Rounding: \
$rounding Fastmath: $fastmath_on\n", color=:green)

    # Results table formatting. Each task specified in the JSON file has
    # a result .txt file named accordingly (plus a HEX_ variant with the
    # input/output printed as raw bit patterns).
    fe = FormatExpr("{1:<10s} {2:>15s} {3:>30s} {4:>30s} {5:>30s} {6:>20s}\n")
    result_table_head = format(fe, "Function", "ULPs", "Input", "Output",
                               "MPFR", "Tests")
    fe = FormatExpr("{1:<10s} {2:>15s} {3:>30s} {4:>30s} {5:>20s}\n")
    result_table_head_hex = format(fe, "Function", "ULPs", "Input", "Output",
                                   "Tests")
    # "w" mode truncates: each run rewrites the task's result files.
    open("output/$task_name.txt", "w") do file
        write(file, result_table_head)
    end
    open("output/HEX_$task_name.txt", "w") do file_hex
        write(file_hex, result_table_head_hex)
    end
    # Row formats for the decimal and hexadecimal result tables.
    fe = FormatExpr("{1:<10s} {2:>15.10f} {3:>30.15e} \
{4:>30.15e} {5:>30.15e} {6:>20d}\n")
    fe_hex = FormatExpr("{1:<10s} {2:>15.10f} {3:>#30x} {4:>#30x} \
{5:>20d}\n")

    # Calculate the time budget (in nanoseconds) available for testing
    # each function, per thread, according to the search strategy.
    ns_budget = 0
    if search == "seconds"
        ns_budget = 10^9
    elseif search == "minutes"
        ns_budget = 60*10^9
    elseif search == "hours"
        ns_budget = 3600*10^9
    elseif search == "days"
        ns_budget = 24*3600*10^9
    end

    # Loop through the functions list of a particular format.
    for (func_name, v) in Functions.functions_dict[data_format]

        # Input-domain endpoints for this function.
        start_float = Functions.functions_dict[data_format][func_name][1]
        end_float = Functions.functions_dict[data_format][func_name][2]

        # Check the number of floating-point numbers in the function's input
        # domain; used in calculating sub-intervals for the different threads.
        num_floats = Err.number_of_floats_in_interval(
            start_float, end_float, data_format)
        t_num_floats = floor(num_floats/Threads.nthreads())

        # For timed strategies, calibrate how many tests fit in ns_budget;
        # exhaustive search ignores tests_to_do.
        tests_to_do = 0
        if search != "exhaustive"
            tests_to_do = check_run_time(ns_budget, start_float, end_float, t_num_floats,
                func_name, data_format, rounding, fastmath_on, search)
        end

        if search == "exhaustive"
            @printf("Running %d tests (search strategy exhaustive) for the function \
%s with %d threads \n", num_floats, func_name, Threads.nthreads())
        else
            @printf("Running %d tests (search strategy \"%s\") for the function \
%s with %d threads \n",
                tests_to_do*Threads.nthreads(), search, func_name, Threads.nthreads())
        end
        flush(stdout)

        # Run the actual search; results land in the shared padded arrays.
        spawn_threads(start_float, end_float, t_num_floats,
                      func_name, data_format, rounding, fastmath_on,
                      tests_to_do, search)

        # Run tests on special inputs, on the main thread, storing the
        # result in the extra slot nthreads()+1. For binary16 this slot
        # keeps its zero initialization.
        if (data_format != "binary16")
            input_set = Functions.spec_inputs_dict[data_format][func_name]
            (max_error[Threads.nthreads()+1].val,
             max_input[Threads.nthreads()+1].val,
             max_output[Threads.nthreads()+1].val,
             max_ref_out[Threads.nthreads()+1].val,
             number_of_tests[Threads.nthreads()+1].val) =
                Err.function_max_error_special_inputs(
                    func_name, data_format, rounding, fastmath_on, input_set)
        end

        # Find index of the maximum error across all slots and report the
        # winning (error, input, output, reference) row to the output file.
        i = findmax([s.val for s in max_error])[2]
        line = format(fe, func_name, max_error[i].val, max_input[i].val, max_output[i].val,
                      max_ref_out[i].val, sum([s.val for s in number_of_tests]))
        open("output/$task_name.txt", "a") do file
            write(file, line);
        end
        # Same row for the hex table: input/output reinterpreted as the
        # format's unsigned-integer bit pattern.
        line = format(fe_hex, func_name, max_error[i].val,
                      reinterpret(Err.uint_formats[data_format],
                                  Err.formats[data_format](max_input[i].val)),
                      reinterpret(Err.uint_formats[data_format],
                                  Err.formats[data_format](max_output[i].val)),
                      sum([s.val for s in number_of_tests]))
        open("output/HEX_$task_name.txt", "a") do file_hex
            write(file_hex, line);
        end
    end
end
0 commit comments