Skip to content

Commit 6f3d599

Browse files
sdaultonfacebook-github-bot
authored andcommitted
Preserve order in ChainedOutcomeTransform (#440)
Summary: Pull Request resolved: #440 Using OrderedDict, perserves the order. Otherwise, ModuleDict orders the transforms alphanumerically by key. Reviewed By: Balandat Differential Revision: D21477968 fbshipit-source-id: aa85d450ccb9b8baa651f2a82df526e91e4daa69
1 parent 4bc0352 commit 6f3d599

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

botorch/models/transforms/outcome.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from __future__ import annotations
88

99
from abc import ABC, abstractmethod
10+
from collections import OrderedDict
1011
from typing import List, Optional, Tuple
1112

1213
import torch
@@ -87,7 +88,7 @@ def __init__(self, **transforms: OutcomeTransform) -> None:
8788
kwargs are used as the keys for accessing the individual
8889
transforms on the module.
8990
"""
90-
super().__init__(transforms)
91+
super().__init__(OrderedDict(transforms))
9192

9293
def forward(
9394
self, Y: Tensor, Yvar: Optional[Tensor] = None

test/models/transforms/test_outcome.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -328,11 +328,11 @@ def test_chained_outcome_transform(self):
328328
# test init
329329
tf1 = Log()
330330
tf2 = Standardize(m=m, batch_shape=batch_shape)
331-
tf = ChainedOutcomeTransform(log=tf1, standardize=tf2)
331+
tf = ChainedOutcomeTransform(b=tf1, a=tf2)
332332
self.assertTrue(tf.training)
333-
self.assertEqual(sorted(tf.keys()), ["log", "standardize"])
334-
self.assertEqual(tf["log"], tf1)
335-
self.assertEqual(tf["standardize"], tf2)
333+
self.assertEqual(list(tf.keys()), ["b", "a"])
334+
self.assertEqual(tf["b"], tf1)
335+
self.assertEqual(tf["a"], tf2)
336336

337337
# make copies for validation below
338338
tf1_, tf2_ = deepcopy(tf1), deepcopy(tf2)

0 commit comments

Comments
 (0)