|
| 1 | +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 2 | +# See https://llvm.org/LICENSE.txt for license information. |
| 3 | +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 4 | +# Also available under a BSD-style license. See LICENSE. |
| 5 | + |
| 6 | +import torch |
| 7 | + |
| 8 | +from torch_mlir_e2e_test.torchscript.framework import TestUtils |
| 9 | +from torch_mlir_e2e_test.torchscript.registry import register_test_case |
| 10 | +from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export |
| 11 | + |
| 12 | + |
# ==============================================================================
# Global parameters
# Largest segment slot the model keeps: forward() shifts gathered segment ids
# by +1 and resets any shifted id greater than this value to 0. The test draws
# raw segment values from [0, NUM_SEGMENTS).
NUM_SEGMENTS = 42
# Number of equal-width bins the [0, 1] prediction range is split into for
# each segment slot (bin width is 1 / NUM_BINS).
NUM_BINS = 5000
# Number of logits the test feeds to the model; forward() also uses it as the
# slice length when reading its offsets input.
NUM_LOGITS = 5000
| 18 | + |
class HistogramBinningCalibrationByFeature(torch.nn.Module):
    """Histogram-binning calibration of logits, keyed by a segment feature.

    Each logit is turned into a prediction in (0, 1), assigned to one of
    ``NUM_BINS`` equal-width bins, and offset by its segment slot so that
    every slot owns a separate run of bins. Per-bin statistics are read from
    the ``_bin_num_positives`` / ``_bin_num_examples`` buffers and blended
    with the original prediction.
    """

    def __init__(self):
        super().__init__()
        self._num_segments = NUM_SEGMENTS
        self._num_bins = NUM_BINS
        self._num_logits = NUM_LOGITS
        # One run of `_num_bins` bins per segment slot; slot 0 is the fallback
        # slot, hence `_num_segments + 1` runs.
        _num_interval = (self._num_segments + 1) * self._num_bins
        _lower_bound = 0
        _upper_bound = 1
        bin_width = (_upper_bound - _lower_bound) / self._num_bins
        self.step = bin_width
        self.register_buffer(
            "_boundaries",
            torch.arange(
                _lower_bound + bin_width,
                _upper_bound - bin_width / 2,
                bin_width,
            ),
        )
        self.register_buffer(
            "_bin_num_examples",
            torch.zeros([_num_interval], dtype=torch.float64),
        )
        self.register_buffer(
            "_bin_num_positives",
            torch.zeros([_num_interval], dtype=torch.float64),
        )
        self.register_buffer("_bin_ids", torch.arange(_num_interval))
        self.positive_weight = torch.tensor([0.4])
        self.bin_ctr_in_use_after = 0
        self.bin_ctr_weight_value = 0.9995
        self.oneminusbin_ctr_weight_value = 0.0005
        self._iteration = 0

    @export
    @annotate_args([
        None,
        ([-1], torch.int32, True),
        ([-1], torch.int32, True),
        ([-1], torch.float32, True),
    ])
    def forward(self, segment_value, segment_lengths, logit):
        """Calibrate ``logit`` and return the bin each prediction fell into.

        Args:
            segment_value: 1-D int32 tensor of segment ids.
            segment_lengths: 1-D int32 tensor indexed as offsets into
                ``segment_value``; entry ``i + 1`` being greater than entry
                ``i`` marks logit ``i`` as having a segment value. The test
                in this file passes cumulative offsets with a leading zero.
            logit: 1-D float32 tensor of raw logits.

        Returns:
            A tuple ``(calibrated_prediction_data, bin_ids_data)``: the
            float32 calibrated predictions and the int64 flat bin index
            used for each logit.
        """
        origin_prediction = torch.sigmoid(
            logit + torch.log(self.positive_weight))

        # Segment slot per logit: gathered id + 1 where the logit has a
        # segment value, 0 otherwise.
        dense_segment_value = torch.zeros(logit.numel(), dtype=torch.int32)
        validoffsets = torch.gt(
            segment_lengths[1:self._num_logits+1],
            segment_lengths[0:self._num_logits])
        gathered_segment_values = (
            segment_value[segment_lengths[0:self._num_logits].long()]+1).int()
        dense_segment_value = torch.where(
            validoffsets, gathered_segment_values, dense_segment_value)

        # Slots beyond the configured number of segments fall back to slot 0.
        zeros = torch.empty_like(
            dense_segment_value, dtype=torch.int32).fill_(0)
        isnotvalid = torch.gt(dense_segment_value, self._num_segments)
        dense_segment_value = torch.where(
            isnotvalid, zeros, dense_segment_value)

        # Bin index inside the slot, then shifted to the slot's run of bins.
        bin_ids_data = torch.ceil(origin_prediction/self.step)-1
        bin_ids_data = bin_ids_data.long()
        segment_bin_offset = dense_segment_value * self._num_bins
        bin_ids_data = bin_ids_data + segment_bin_offset

        bin_num_positives = self._bin_num_positives[bin_ids_data]
        curr_bin_num_examples = self._bin_num_examples[bin_ids_data]
        # Bins with zero examples produce NaN here (0 / 0); those entries are
        # discarded by the `isvalid` selection below.
        bin_ctr = bin_num_positives / curr_bin_num_examples
        bin_ctr = bin_ctr.float()
        blended_prediction = bin_ctr * self.bin_ctr_weight_value + \
            origin_prediction * self.oneminusbin_ctr_weight_value

        # Only trust a bin once it has more than `bin_ctr_in_use_after`
        # examples; otherwise keep the uncalibrated prediction.
        isvalid = torch.gt(curr_bin_num_examples,
                           self.bin_ctr_in_use_after)
        calibrated_prediction_data = torch.where(
            isvalid, blended_prediction, origin_prediction.float())
        return calibrated_prediction_data, bin_ids_data
| 85 | + |
| 86 | + |
@register_test_case(module_factory=lambda: HistogramBinningCalibrationByFeature())
def HBC_basic(module, tu: TestUtils):
    """Run the calibration module on random logits and random segments.

    Each logit gets a segment of length 0 or 1. The lengths are converted to
    cumulative offsets with a leading zero, and the segment values are
    zero-padded to ``NUM_LOGITS`` entries so the inputs have fixed shapes.
    """
    logits = torch.rand(NUM_LOGITS, dtype=torch.float)
    segment_lengths: torch.Tensor = torch.randint(
        0, 2, (NUM_LOGITS,), dtype=torch.int)
    segment_offsets: torch.Tensor = torch.cumsum(segment_lengths, 0)
    segment_offsets = torch.cat(
        (torch.tensor([0]), segment_offsets), 0)
    num_values: int = int(torch.sum(segment_lengths).item())
    segment_values: torch.Tensor = torch.randint(
        0,
        NUM_SEGMENTS,
        (num_values,),
    )
    # Pad with zeros of the same dtype so the concatenation does not depend
    # on implicit int/float type promotion.
    padding = torch.zeros(
        NUM_LOGITS - segment_values.numel(), dtype=segment_values.dtype)
    segment_values = torch.cat((segment_values, padding), 0)
    # Input shapes: segment_values (5000,), segment_offsets (5001,),
    # logits (5000,).
    module.forward(segment_values.int(), segment_offsets.int(), logits)
0 commit comments