Skip to content

Commit ef7fa96

Browse files
d4l3k authored and pytorchmergebot committed
dist: add list_keys to Store API (pytorch#167883)
This adds a `list_keys` Store API and implements it for all backends. It is intended for debugging and allows inspecting all keys in a store locally, as well as remotely in the case of TCPStore. Test plan: `pytest test/distributed/test_store.py`. Pull Request resolved: pytorch#167883. Approved by: https://github.com/fduwjj
1 parent 7ffeb34 commit ef7fa96

File tree

15 files changed

+128
-0
lines changed

15 files changed

+128
-0
lines changed

test/distributed/test_store.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,14 @@ def test_clone(self):
253253
a.set("foo", "bar")
254254
self.assertEqual(b.get("foo"), b"bar")
255255

256+
def test_list_keys(self):
257+
a = self._create_store()
258+
a.set("foo", "bar")
259+
a.set("baz", "qux")
260+
keys = a.list_keys()
261+
self.assertIn("foo", keys)
262+
self.assertIn("baz", keys)
263+
256264
# This is the number of keys used in test_set_get. Adding this as a class
257265
# property instead of hardcoding in the test since some Store
258266
# implementations will have differing number of keys. In the base case,

torch/_C/_distributed_c10d.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ class Store:
215215
def queue_pop(self, key: str, block: bool = True) -> bytes: ...
216216
def queue_push(self, key: str, value: Union[bytes, str]) -> None: ...
217217
def queue_len(self, key: str) -> int: ...
218+
def list_keys(self) -> list[str]: ...
218219

219220
class FileStore(Store):
220221
def __init__(self, path: str, numWorkers: int = ...) -> None: ...

torch/csrc/distributed/c10d/FileStore.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,4 +492,17 @@ void FileStore::wait(
492492
}
493493
}
494494

495+
std::vector<std::string> FileStore::listKeys() {
496+
std::unique_lock<std::mutex> l(activeFileOpLock_);
497+
File file(path_, O_RDONLY, timeout_);
498+
auto lock = file.lockShared();
499+
pos_ = refresh(file, pos_, cache_, deletePrefix_);
500+
std::vector<std::string> keys;
501+
keys.reserve(cache_.size());
502+
for (const auto& kv : cache_) {
503+
keys.push_back(kv.first.substr(regularPrefix_.size()));
504+
}
505+
return keys;
506+
}
507+
495508
} // namespace c10d

torch/csrc/distributed/c10d/FileStore.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ class TORCH_API FileStore : public Store {
4545
return path_;
4646
}
4747

48+
std::vector<std::string> listKeys() override;
49+
4850
protected:
4951
int64_t addHelper(const std::string& key, int64_t i);
5052

torch/csrc/distributed/c10d/HashStore.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,4 +217,14 @@ int64_t HashStore::queueLen(const std::string& key) {
217217
return static_cast<int64_t>(it->second.size());
218218
}
219219

220+
std::vector<std::string> HashStore::listKeys() {
221+
std::unique_lock<std::mutex> lock(m_);
222+
std::vector<std::string> keys;
223+
keys.reserve(map_.size());
224+
for (const auto& kv : map_) {
225+
keys.push_back(kv.first);
226+
}
227+
return keys;
228+
}
229+
220230
} // namespace c10d

torch/csrc/distributed/c10d/HashStore.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ class TORCH_API HashStore : public Store {
5959

6060
int64_t queueLen(const std::string& key) override;
6161

62+
std::vector<std::string> listKeys() override;
63+
6264
protected:
6365
bool checkLocked(
6466
const std::unique_lock<std::mutex>& lock,

torch/csrc/distributed/c10d/PrefixStore.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,4 +146,18 @@ c10::intrusive_ptr<Store> PrefixStore::getUnderlyingNonPrefixStore() {
146146
return store;
147147
}
148148

149+
std::vector<std::string> PrefixStore::listKeys() {
150+
auto keys = store_->listKeys();
151+
std::vector<std::string> filteredKeys;
152+
filteredKeys.reserve(keys.size());
153+
154+
for (auto& key : keys) {
155+
if (key.find(prefix_) == 0) {
156+
key = key.substr(prefix_.size() + 1);
157+
filteredKeys.push_back(std::move(key));
158+
}
159+
}
160+
return filteredKeys;
161+
}
162+
149163
} // namespace c10d

torch/csrc/distributed/c10d/PrefixStore.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ class TORCH_API PrefixStore : public Store {
6464
// Recursively to fetch the store before layers of wrapping with PrefixStore.
6565
c10::intrusive_ptr<Store> getUnderlyingNonPrefixStore();
6666

67+
std::vector<std::string> listKeys() override;
68+
6769
protected:
6870
std::string prefix_;
6971
c10::intrusive_ptr<Store> store_;

torch/csrc/distributed/c10d/Store.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,11 @@ class TORCH_API Store : public torch::CustomClassHolder {
114114
C10_THROW_ERROR(NotImplementedError, "queue support is not implemented.");
115115
}
116116

117+
// Lists the keys held by the store.
// This base implementation always throws NotImplementedError; concrete
// stores override it.
virtual std::vector<std::string> listKeys() {
  C10_THROW_ERROR(NotImplementedError, "listKeys support is not implemented.");
}
121+
117122
protected:
118123
std::chrono::milliseconds timeout_;
119124
};

torch/csrc/distributed/c10d/TCPStore.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,30 @@ int64_t TCPStore::queueLen(const std::string& key) {
723723
return client_->receiveValue<int64_t>();
724724
}
725725

726+
std::vector<std::string> TCPStore::listKeys() {
727+
STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__list);
728+
729+
const std::lock_guard<std::mutex> lock(activeOpLock_);
730+
731+
detail::SendBuffer buffer(*client_, detail::QueryType::LIST_KEYS);
732+
buffer.flush();
733+
734+
auto numKeys = client_->receiveValue<int64_t>();
735+
std::vector<std::string> keys;
736+
keys.reserve(numKeys);
737+
for (auto i = 0; i < numKeys; ++i) {
738+
auto bits = client_->receiveBits();
739+
std::string str(bits.begin(), bits.end());
740+
if (str.find(keyPrefix_) == 0) {
741+
str = str.substr(keyPrefix_.size());
742+
} else {
743+
continue;
744+
}
745+
keys.emplace_back(str);
746+
}
747+
return keys;
748+
}
749+
726750
bool TCPStore::hasExtendedApi() const {
727751
return true;
728752
}

0 commit comments

Comments
 (0)