Skip to content

Commit faa5f90

Browse files
committed
Merge pull request google-deepmind#1054 from maichmueller:feat/infostate_binding
PiperOrigin-RevId: 859599054 Change-Id: I75fcf8f11af2d89638a3c130a533398784583fcc
2 parents bbc9c1a + 9025b70 commit faa5f90

File tree

10 files changed

+901
-20
lines changed

10 files changed

+901
-20
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@ open_spiel/games/universal_poker/double_dummy_solver/
3838
open_spiel/games/hanabi/hanabi-learning-environment/
3939
/open_spiel/pybind11_abseil/
4040
pybind11/
41+
!open_spiel/python/pybind11
4142

4243
# Install artifacts
4344
download_cache/

open_spiel/algorithms/infostate_tree.cc

Lines changed: 12 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -214,15 +214,16 @@ void InfostateTree::CollectNodesAtDepth(InfostateNode* node, size_t depth) {
214214
CollectNodesAtDepth(child, depth + 1);
215215
}
216216

217-
std::ostream& InfostateTree::operator<<(std::ostream& os) const {
218-
return os << "Infostate tree for player " << acting_player_ << ".\n"
219-
<< "Tree height: " << tree_height_ << '\n'
220-
<< "Root branching: " << root_branching_factor() << '\n'
221-
<< "Number of decision infostate nodes: " << num_decisions() << '\n'
222-
<< "Number of sequences: " << num_sequences() << '\n'
223-
<< "Number of leaves: " << num_leaves() << '\n'
217+
std::ostream& operator<<(std::ostream& os, const InfostateTree& tree) {
218+
return os << "Infostate tree for player " << tree.acting_player_ << ".\n"
219+
<< "Tree height: " << tree.tree_height_ << '\n'
220+
<< "Root branching: " << tree.root_branching_factor() << '\n'
221+
<< "Number of decision infostate nodes: " << tree.num_decisions()
222+
<< '\n'
223+
<< "Number of sequences: " << tree.num_sequences() << '\n'
224+
<< "Number of leaves: " << tree.num_leaves() << '\n'
224225
<< "Tree certificate: " << '\n'
225-
<< root().MakeCertificate() << '\n';
226+
<< tree.root().MakeCertificate() << '\n';
226227
}
227228

228229
std::unique_ptr<InfostateNode> InfostateTree::MakeNode(
@@ -536,7 +537,7 @@ absl::optional<DecisionId> InfostateTree::DecisionIdForSequence(
536537
}
537538
}
538539
absl::optional<InfostateNode*> InfostateTree::DecisionForSequence(
539-
const SequenceId& sequence_id) {
540+
const SequenceId& sequence_id) const {
540541
SPIEL_DCHECK_TRUE(sequence_id.BelongsToTree(this));
541542
InfostateNode* node = sequences_.at(sequence_id.id());
542543
SPIEL_DCHECK_TRUE(node);
@@ -646,7 +647,7 @@ std::pair<size_t, size_t> InfostateTree::CollectStartEndSequenceIds(
646647
}
647648

648649
std::pair<double, SfStrategy> InfostateTree::BestResponse(
649-
TreeplexVector<double>&& gradient) const {
650+
TreeplexVector<double> gradient) const {
650651
SPIEL_CHECK_EQ(this, gradient.tree());
651652
SPIEL_CHECK_EQ(num_sequences(), gradient.size());
652653
SfStrategy response(this);
@@ -698,7 +699,7 @@ std::pair<double, SfStrategy> InfostateTree::BestResponse(
698699
return {gradient[empty_sequence()], response};
699700
}
700701

701-
double InfostateTree::BestResponseValue(LeafVector<double>&& gradient) const {
702+
double InfostateTree::BestResponseValue(LeafVector<double> gradient) const {
702703
// Loop over all heights.
703704
for (int d = tree_height_ - 1; d >= 0; d--) {
704705
int left_offset = 0;

open_spiel/algorithms/infostate_tree.h

Lines changed: 15 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -240,6 +240,9 @@ class RangeIterator {
240240
bool operator!=(const RangeIterator& other) const {
241241
return id_ != other.id_ || tree_ != other.tree_;
242242
}
243+
bool operator==(const RangeIterator& other) const {
244+
return !(this->operator!=(other));
245+
}
243246
Id operator*() { return Id(id_, tree_); }
244247
};
245248

@@ -289,7 +292,7 @@ std::shared_ptr<InfostateTree> MakeInfostateTree(
289292
const std::vector<InfostateNode*>& start_nodes,
290293
bool store_world_states = false, int max_move_ahead_limit = 1000);
291294

292-
class InfostateTree final {
295+
class InfostateTree final : public std::enable_shared_from_this<InfostateTree> {
293296
// Note that only MakeInfostateTree is allowed to call the constructor
294297
// to ensure the trees are always allocated on heap. We do this so that all
295298
// the collected pointers are valid throughout the tree's lifetime even if
@@ -310,6 +313,12 @@ class InfostateTree final {
310313
const std::vector<const InfostateNode*>&, bool, int);
311314

312315
public:
316+
// -- gain shared ownership of the allocated infostate object
317+
std::shared_ptr<InfostateTree> shared_ptr() { return shared_from_this(); }
318+
std::shared_ptr<const InfostateTree> shared_ptr() const {
319+
return shared_from_this();
320+
}
321+
313322
// -- Root accessors ---------------------------------------------------------
314323
const InfostateNode& root() const { return *root_; }
315324
InfostateNode* mutable_root() { return root_.get(); }
@@ -349,7 +358,8 @@ class InfostateTree final {
349358
// Returns `None` if the sequence is the empty sequence.
350359
absl::optional<DecisionId> DecisionIdForSequence(const SequenceId&) const;
351360
// Returns `None` if the sequence is the empty sequence.
352-
absl::optional<InfostateNode*> DecisionForSequence(const SequenceId&);
361+
absl::optional<InfostateNode*> DecisionForSequence(
362+
const SequenceId& sequence_id) const;
353363
// Returns whether the sequence ends with the last action the player can make.
354364
bool IsLeafSequence(const SequenceId&) const;
355365

@@ -392,13 +402,13 @@ class InfostateTree final {
392402
// Compute best response and value based on gradient from opponents.
393403
// This consumes the gradient vector, as it is used to compute the value.
394404
std::pair<double, SfStrategy> BestResponse(
395-
TreeplexVector<double>&& gradient) const;
405+
TreeplexVector<double> gradient) const;
396406
// Compute best response value based on gradient from opponents over leaves.
397407
// This consumes the gradient vector, as it is used to compute the value.
398-
double BestResponseValue(LeafVector<double>&& gradient) const;
408+
double BestResponseValue(LeafVector<double> gradient) const;
399409

400410
// -- For debugging ----------------------------------------------------------
401-
std::ostream& operator<<(std::ostream& os) const;
411+
friend std::ostream& operator<<(std::ostream& os, const InfostateTree& tree);
402412

403413
private:
404414
const Player acting_player_;

open_spiel/python/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,9 @@ endif()
2626
# List of all Python bindings to add to pyspiel.
2727
include_directories (../pybind11_abseil ../../pybind11/include)
2828
set(PYTHON_BINDINGS ${PYTHON_BINDINGS}
29+
pybind11/algorithms_infostate_tree.cc
30+
pybind11/algorithms_infostate_tree.tcc
31+
pybind11/algorithms_infostate_tree.h
2932
pybind11/algorithms_corr_dist.cc
3033
pybind11/algorithms_corr_dist.h
3134
pybind11/algorithms_trajectories.cc
@@ -197,6 +200,7 @@ set(PYTHON_TESTS ${PYTHON_TESTS}
197200
algorithms/generate_playthrough_test.py
198201
algorithms/get_all_states_test.py
199202
algorithms/ismcts_agent_test.py
203+
algorithms/infostate_tree_test.py
200204
algorithms/mcts_agent_test.py
201205
algorithms/mcts_test.py
202206
algorithms/minimax_test.py

0 commit comments

Comments
 (0)