Skip to content

Commit 17a4843

Browse files
Adding an e2e test for histogram binning calibration
1 parent cadea67 commit 17a4843

File tree

2 files changed

+105
-0
lines changed

2 files changed

+105
-0
lines changed
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
# Also available under a BSD-style license. See LICENSE.
5+
6+
import torch
7+
8+
from torch_mlir_e2e_test.torchscript.framework import TestUtils
9+
from torch_mlir_e2e_test.torchscript.registry import register_test_case
10+
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
11+
12+
13+
# ==============================================================================
14+
# Global parameters
15+
NUM_SEGMENTS = 42
16+
NUM_BINS = 5000
17+
NUM_LOGITS = 5000
18+
19+
class HistogramBinningCalibrationByFeature(torch.nn.Module):
20+
def __init__(self):
21+
super().__init__()
22+
self._num_segments = NUM_SEGMENTS
23+
self._num_bins = NUM_BINS
24+
self._num_logits = NUM_LOGITS
25+
_num_interval = (self._num_segments + 1) * self._num_bins
26+
_lower_bound = 0
27+
_upper_bound = 1
28+
l, u = _lower_bound, _upper_bound
29+
w = (u - l) / self._num_bins
30+
self.step = w
31+
self.register_buffer("_boundaries", torch.arange(l + w, u - w / 2, w))
32+
self.register_buffer(
33+
"_bin_num_examples",
34+
torch.empty([_num_interval], dtype=torch.float64).fill_(0.0),
35+
)
36+
self.register_buffer(
37+
"_bin_num_positives",
38+
torch.empty([_num_interval], dtype=torch.float64).fill_(0.0),
39+
)
40+
self.register_buffer("_bin_ids", torch.arange(_num_interval))
41+
self.positive_weight = torch.tensor([0.4])
42+
self.bin_ctr_in_use_after = 0
43+
self.bin_ctr_weight_value = 0.9995
44+
self.oneminusbin_ctr_weight_value = 0.0005
45+
self._iteration = 0
46+
47+
@export
48+
@annotate_args([
49+
None,
50+
([-1], torch.int32, True),
51+
([-1], torch.int32, True),
52+
([-1], torch.float32, True),
53+
])
54+
def forward(self, segment_value, segment_lengths, logit):
55+
origin_prediction = torch.sigmoid(
56+
logit + torch.log(self.positive_weight))
57+
dense_segment_value = torch.zeros(logit.numel(), dtype=torch.int32)
58+
validoffsets = torch.gt(
59+
segment_lengths[1:self._num_logits+1], segment_lengths[0:self._num_logits])
60+
gathered_segment_values = (
61+
segment_value[segment_lengths[0:self._num_logits].long()]+1).int()
62+
dense_segment_value = torch.where(
63+
validoffsets, gathered_segment_values, dense_segment_value)
64+
zeros = torch.empty_like(
65+
dense_segment_value, dtype=torch.int32).fill_(0)
66+
isnotvalid = torch.gt(dense_segment_value, self._num_segments)
67+
dense_segment_value = torch.where(
68+
isnotvalid, zeros, dense_segment_value)
69+
bin_ids_data = torch.ceil(origin_prediction/self.step)-1
70+
bin_ids_data = bin_ids_data.long()
71+
curr_segment_value = dense_segment_value * self._num_bins
72+
bin_ids_data2 = bin_ids_data
73+
bin_ids_data = bin_ids_data + curr_segment_value
74+
curr_segment_value = self._bin_num_positives[bin_ids_data]
75+
curr_bin_num_examples = self._bin_num_examples[bin_ids_data]
76+
curr_segment_value = curr_segment_value / curr_bin_num_examples
77+
curr_segment_value = curr_segment_value.float()
78+
curr_segment_value = curr_segment_value * self.bin_ctr_weight_value + \
79+
origin_prediction * self.oneminusbin_ctr_weight_value
80+
isvalid = torch.gt(curr_bin_num_examples,
81+
self.bin_ctr_in_use_after)
82+
calibrated_prediction_data = torch.where(
83+
isvalid, curr_segment_value, origin_prediction.float())
84+
return calibrated_prediction_data, bin_ids_data
85+
86+
87+
@register_test_case(module_factory=lambda: HistogramBinningCalibrationByFeature())
88+
def HBC_basic(module, tu: TestUtils):
89+
logits = torch.rand(NUM_LOGITS, dtype=torch.float)
90+
segment_lengths: Tensor = torch.randint(
91+
0, 2, (NUM_LOGITS,), dtype=torch.int)
92+
segment_offsets: Tensor = torch.cumsum(segment_lengths, 0)
93+
segment_offsets: Tensor = torch.cat(
94+
(torch.tensor([0]), segment_offsets), 0)
95+
num_values: int = int(torch.sum(segment_lengths).item())
96+
segment_values: Tensor = torch.randint(
97+
0,
98+
NUM_SEGMENTS,
99+
(num_values,),
100+
)
101+
segment_values = torch.cat(
102+
(segment_values, torch.zeros(NUM_LOGITS-segment_values.numel())), 0)
103+
module.forward(segment_values.int(), segment_offsets.int(), logits)
104+
#input shape (5000, 5001, 5000)

e2e_testing/torchscript/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from . import arange
4949
from . import constant_alloc
5050
from . import threshold
51+
from . import histogram_binning_calibration
5152

5253
def _get_argparse():
5354
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']

0 commit comments

Comments
 (0)